私は計算するための高速な方法を探しています
(1:N)'*(1:N)
問題の対称性により、実際に乗算と加算を行うのは無駄になるように感じます。
なぜこれをやりたいのかという問題は本当に重要です。
理論的な意味では、他の回答で提案されている三角形のアプローチにより、操作が節約されます。@jgmaoの答えは、乗算を減らすのに特に興味深いものです。
実際には、CPU 操作の数は、高速なコードを記述するときに最小化するための指標ではなくなりました。CPU 操作が非常に少ない場合、メモリ帯域幅が支配的であるため、キャッシュを意識したアクセス パターンを調整することで、これを高速化できます。行列乗算コードは非常に効率的に実装されます。これは非常に一般的な操作であり、その価値がある BLAS 数値ライブラリのすべての実装では、最適化されたアクセス パターンと SIMD 計算も使用されます。
単純な C を書き、演算数を理論上の最小値まで減らしたとしても、完全な行列乗算には勝てないでしょう。これが要約すると、操作に最もよく一致する数値プリミティブを見つけることです。
そうは言っても、DGEMM (行列乗算) よりも少し近い BLAS 操作があります。これは DSYRK、ランク k 更新と呼ばれ、正確に使用できますA'*A
。このためにずっと前に書いた MEX 関数はこちらです。私は長い間それをいじっていませんでしたが、最初に書いたときはうまくいき、実際にはストレートよりも速く実行されましたA'*A
.
/* xtrx.c: calculates x'*x taking advantage of the symmetry.
Peter Boettcher <email removed>
Last modified: <Thu Jan 23 13:53:02 2003> */
#include "mex.h"
const double one = 1;
const double zero = 0;
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
double *x, *z;
int i, j, mrows, ncols;
if(nrhs!=1) mexErrMsgTxt("One input required.");
x = mxGetPr(prhs[0]);
mrows = mxGetM(prhs[0]);
ncols = mxGetN(prhs[0]);
plhs[0] = mxCreateDoubleMatrix(ncols,ncols, mxREAL);
z = mxGetPr(plhs[0]);
/* Call the FORTRAN BLAS routine for rank k update */
dsyrk_("U", "T", &ncols, &mrows, &one, x, &mrows, &zero, z, &ncols);
/* Result is in the upper triangle. Copy it down the lower part */
for(i=0; i<ncols; i++)
for(j=i+1; j<ncols; j++)
z[i*ncols + j] = z[j*ncols + i];
}
これは (1:N).'*(1:N) よりも 3 倍高速です (結果が許容範囲内であれば) (数値がの代わりにint32
使用するのに十分小さい場合はさらに高速です)。int16
int32
N = 1000;
aux = int32(1:N);
result = bsxfun(@times,aux.',aux);
ベンチマーク:
>> N = 1000; aux = int32(1:N); tic, for count = 1:1e2, bsxfun(@times,aux.',aux); end, toc
Elapsed time is 0.734992 seconds.
>> N = 1000; aux = 1:N; tic, for count = 1:1e2, aux.'*aux; end, toc
Elapsed time is 2.281784 seconds.
aux.'*aux
には使用できませんのでご注意くださいaux = int32(1:N)
。
@DanielE.Shub で指摘されているように、結果がdouble
行列として必要な場合は、最終的なキャストを行う必要があり、その場合、ゲインは非常に小さくなります。
>> N = 1000; aux = int32(1:N); tic, for count = 1:1e2, double(bsxfun(@times,aux.',aux)); end, toc
Elapsed time is 2.173059 seconds.
入力の特別な順序構造のため、N=4 の場合を考えてください。
(1:4)'*(1:4) = [1 2 3 4
2 4 6 8
3 6 9 12
4 8 12 16]
1行目はちょうど(1:N)であり、2行目(j = 2)行から、この行の値は前の行(j = 1)に(1:N)を加えたものであることがわかります。ですから 1. 掛け算はあまりしません。代わりに、N*N 回の加算によって生成できます。2. 出力は対称であるため、出力行列の半分だけを計算する必要があります。したがって、合計計算は (N-1)+(N-2)+...+1 = N^2 / 2 回の加算です。