編集:より正確な時間を与えるためにテストを改善しました。また、展開されたバージョンを最適化しましたが、これは最初に持っていたものよりもはるかに優れていますが、サイズを大きくすると行列の乗算ははるかに高速になります。
EDIT2: JIT コンパイラが展開された関数で動作していることを確認するために、生成された関数を M ファイルとして書き込むようにコードを変更しました。また、TIMEIT に関数ハンドルを渡すことで両方のメソッドが評価されるため、比較は公平に行われます。timeit(@myfunc)
あなたのアプローチが合理的なサイズの行列乗算よりも高速であるとは確信していません。それでは、2 つの方法を比較してみましょう。
Symbolic Math Toolbox を使用して、次の方程式の「展開」形式を取得していますx'*A*x
(20x20 行列と 20x1 ベクトルを手で掛けてみてください!):
function f = buildUnrolledFunction(N)
% avoid regenerating files, CCODE below can be really slow!
fname = sprintf('f%d',N);
if exist([fname '.m'], 'file')
f = str2func(fname);
return
end
% construct symbolic vector/matrix of the specified size
x = sym('x', [N 1]);
A = sym('A', [N N]);
% work out the expanded form of the matrix-multiplication
% and convert it to a string
s = ccode(expand(x.'*A*x)); % instead of char(.) to avoid x^2
% a bit of RegExp to fix the notation of the variable names
% also convert indexing into linear indices: A(3,3) into A(9)
s = regexprep(regexprep(s, '^.*=\s+', ''), ';$', '');
s = regexprep(regexprep(s, 'x(\d+)', 'x($1)'), 'A(\d+)_(\d+)', ...
'A(${ int2str(sub2ind([N N],str2num($1),str2num($2))) })');
% build an M-function from the string, and write it to file
fid = fopen([fname '.m'], 'wt');
fprintf(fid, 'function v = %s(A,x)\nv = %s;\nend\n', fname, s);
fclose(fid);
% rehash path and return a function handle
rehash
clear(fname)
f = str2func(fname);
end
累乗を避けることで、生成された関数を最適化しようとしました (私たちは を好みx*x
ますx^2
)。A(9)
また、添字を (の代わりに)線形インデックスに変換しましたA(3,3)
。したがって、n=3
あなたが持っていたのと同じ方程式が得られます。
>> s
s =
A(1)*(x(1)*x(1)) + A(5)*(x(2)*x(2)) + A(9)*(x(3)*x(3)) +
A(4)*x(1)*x(2) + A(7)*x(1)*x(3) + A(2)*x(1)*x(2) +
A(8)*x(2)*x(3) + A(3)*x(1)*x(3) + A(6)*x(2)*x(3)
M 関数を構築する上記の方法を考慮して、さまざまなサイズについて評価し、行列乗算形式と比較します (関数呼び出しのオーバーヘッドを考慮して別の関数に入れました)。より正確なタイミングを取得する代わりに、 TIMEIT関数を使用しています。tic/toc
また、公平に比較するために、各メソッドは、必要なすべての変数を入力引数として渡す M ファイル関数として実装されています。
function results = testMatrixMultVsUnrolled()
% vector/matrix size
N_vec = 2:50;
results = zeros(numel(N_vec),3);
for ii = 1:numel(N_vec);
% some random data
N = N_vec(ii);
x = rand(N,1); A = rand(N,N);
% matrix multiplication
f = @matMult;
results(ii,1) = timeit(@() feval(f, A,x));
% unrolled equation
f = buildUnrolledFunction(N);
results(ii,2) = timeit(@() feval(f, A,x));
% check result
results(ii,3) = norm(matMult(A,x) - f(A,x));
end
% display results
fprintf('N = %2d: mtimes = %.6f ms, unroll = %.6f ms [error = %g]\n', ...
[N_vec(:) results(:,1:2)*1e3 results(:,3)]')
plot(N_vec, results(:,1:2)*1e3, 'LineWidth',2)
xlabel('size (N)'), ylabel('timing [msec]'), grid on
legend({'mtimes','unrolled'})
title('Matrix multiplication: $$x^\mathsf{T}Ax$$', ...
'Interpreter','latex', 'FontSize',14)
end
function v = matMult(A,x)
v = x.' * A * x;
end
結果:

N = 2: mtimes = 0.008816 ms, unroll = 0.006793 ms [error = 0]
N = 3: mtimes = 0.008957 ms, unroll = 0.007554 ms [error = 0]
N = 4: mtimes = 0.009025 ms, unroll = 0.008261 ms [error = 4.44089e-16]
N = 5: mtimes = 0.009075 ms, unroll = 0.008658 ms [error = 0]
N = 6: mtimes = 0.009003 ms, unroll = 0.008689 ms [error = 8.88178e-16]
N = 7: mtimes = 0.009234 ms, unroll = 0.009087 ms [error = 1.77636e-15]
N = 8: mtimes = 0.008575 ms, unroll = 0.009744 ms [error = 8.88178e-16]
N = 9: mtimes = 0.008601 ms, unroll = 0.011948 ms [error = 0]
N = 10: mtimes = 0.009077 ms, unroll = 0.014052 ms [error = 0]
N = 11: mtimes = 0.009339 ms, unroll = 0.015358 ms [error = 3.55271e-15]
N = 12: mtimes = 0.009271 ms, unroll = 0.018494 ms [error = 3.55271e-15]
N = 13: mtimes = 0.009166 ms, unroll = 0.020238 ms [error = 0]
N = 14: mtimes = 0.009204 ms, unroll = 0.023326 ms [error = 7.10543e-15]
N = 15: mtimes = 0.009396 ms, unroll = 0.024767 ms [error = 3.55271e-15]
N = 16: mtimes = 0.009193 ms, unroll = 0.027294 ms [error = 2.4869e-14]
N = 17: mtimes = 0.009182 ms, unroll = 0.029698 ms [error = 2.13163e-14]
N = 18: mtimes = 0.009330 ms, unroll = 0.033295 ms [error = 7.10543e-15]
N = 19: mtimes = 0.009411 ms, unroll = 0.152308 ms [error = 7.10543e-15]
N = 20: mtimes = 0.009366 ms, unroll = 0.167336 ms [error = 7.10543e-15]
N = 21: mtimes = 0.009335 ms, unroll = 0.183371 ms [error = 0]
N = 22: mtimes = 0.009349 ms, unroll = 0.200859 ms [error = 7.10543e-14]
N = 23: mtimes = 0.009411 ms, unroll = 0.218477 ms [error = 8.52651e-14]
N = 24: mtimes = 0.009307 ms, unroll = 0.235668 ms [error = 4.26326e-14]
N = 25: mtimes = 0.009425 ms, unroll = 0.256491 ms [error = 1.13687e-13]
N = 26: mtimes = 0.009392 ms, unroll = 0.274879 ms [error = 7.10543e-15]
N = 27: mtimes = 0.009515 ms, unroll = 0.296795 ms [error = 2.84217e-14]
N = 28: mtimes = 0.009567 ms, unroll = 0.319032 ms [error = 5.68434e-14]
N = 29: mtimes = 0.009548 ms, unroll = 0.339517 ms [error = 3.12639e-13]
N = 30: mtimes = 0.009617 ms, unroll = 0.361897 ms [error = 1.7053e-13]
N = 31: mtimes = 0.009672 ms, unroll = 0.387270 ms [error = 0]
N = 32: mtimes = 0.009629 ms, unroll = 0.410932 ms [error = 1.42109e-13]
N = 33: mtimes = 0.009605 ms, unroll = 0.434452 ms [error = 1.42109e-13]
N = 34: mtimes = 0.009534 ms, unroll = 0.462961 ms [error = 0]
N = 35: mtimes = 0.009696 ms, unroll = 0.489474 ms [error = 5.68434e-14]
N = 36: mtimes = 0.009691 ms, unroll = 0.512198 ms [error = 8.52651e-14]
N = 37: mtimes = 0.009671 ms, unroll = 0.544485 ms [error = 5.68434e-14]
N = 38: mtimes = 0.009710 ms, unroll = 0.573564 ms [error = 8.52651e-14]
N = 39: mtimes = 0.009946 ms, unroll = 0.604567 ms [error = 3.41061e-13]
N = 40: mtimes = 0.009735 ms, unroll = 0.636640 ms [error = 3.12639e-13]
N = 41: mtimes = 0.009858 ms, unroll = 0.665719 ms [error = 5.40012e-13]
N = 42: mtimes = 0.009876 ms, unroll = 0.697364 ms [error = 0]
N = 43: mtimes = 0.009956 ms, unroll = 0.730506 ms [error = 2.55795e-13]
N = 44: mtimes = 0.009897 ms, unroll = 0.765358 ms [error = 4.26326e-13]
N = 45: mtimes = 0.009991 ms, unroll = 0.800424 ms [error = 0]
N = 46: mtimes = 0.009956 ms, unroll = 0.829717 ms [error = 2.27374e-13]
N = 47: mtimes = 0.010210 ms, unroll = 0.865424 ms [error = 2.84217e-13]
N = 48: mtimes = 0.010022 ms, unroll = 0.907974 ms [error = 3.97904e-13]
N = 49: mtimes = 0.010098 ms, unroll = 0.944536 ms [error = 5.68434e-13]
N = 50: mtimes = 0.010153 ms, unroll = 0.984486 ms [error = 4.54747e-13]
サイズが小さい場合、この 2 つの方法のパフォーマンスは多少似ています。N<7
拡張バージョンの場合mtimes
は . 小さなサイズを超えると、行列の乗算は桁違いに速くなります。
これは驚くべきことではありません。式だけN=20
では恐ろしく長く、400 の用語を追加する必要があります。MATLAB 言語は解釈されるため、これが非常に効率的であるとは思えません。
ここで、コードをインラインで直接埋め込む場合と比較して、外部関数を呼び出す場合のオーバーヘッドがあることに同意しますが、そのようなアプローチがどれほど実用的かということです。のような小さなサイズでもN=20
、生成される行は 7000 文字を超えます。また、行が長いために MATLAB エディターの動作が遅くなることに気付きました :)
しかも、 前後でアドバンテージはすぐになくなりN>10
ます。@DennisJaheruddinが提案したものと同様に、埋め込みコード/明示的に記述されたものと行列乗算を比較しました。結果: _
N=3:
Elapsed time is 0.062295 seconds. % unroll
Elapsed time is 1.117962 seconds. % mtimes
N=12:
Elapsed time is 1.024837 seconds. % unroll
Elapsed time is 1.126147 seconds. % mtimes
N=19:
Elapsed time is 140.915138 seconds. % unroll
Elapsed time is 1.305382 seconds. % mtimes
...そして、展開されたバージョンではさらに悪化します。前に述べたように、MATLAB は解釈されるため、コードを解析するコストがこのような巨大なファイルに現れ始めます。
私の見方では、100 万回の反復を行った後、せいぜい 1 秒しか得られませんでした。これは、はるかに読みやすく簡潔なv=x'*A*x
. したがって、行列の乗算などの既に最適化された操作に焦点を当てるのではなく、コード内に改善できる他の場所があるかもしれません。
MATLAB の行列乗算は非常に高速です(これが MATLAB の得意分野です! ) 。 十分な大きさのデータに到達すると(マルチスレッドが開始されるため)、本当に輝きます。
>> N=5000; x=rand(N,1); A=rand(N,N);
>> tic, for i=1e4, v=x.'*A*x; end, toc
Elapsed time is 0.021959 seconds.