6

私は計算するための高速な方法を探しています

(1:N)'*(1:N)

問題の対称性により、実際に乗算と加算を行うのは無駄になるように感じます。

4

4 に答える 4

14

なぜこれをやりたいのかという問題は本当に重要です。

理論的な意味では、他の回答で提案されている三角形のアプローチにより、操作が節約されます。@jgmaoの答えは、乗算を減らすのに特に興味深いものです。

実際には、CPU 操作の数は、高速なコードを記述するときに最小化するための指標ではなくなりました。CPU 操作が非常に少ない場合、メモリ帯域幅が支配的であるため、キャッシュを意識したアクセス パターンを調整することで、これを高速化できます。行列乗算コードは非常に効率的に実装されます。これは非常に一般的な操作であり、その価値がある BLAS 数値ライブラリのすべての実装では、最適化されたアクセス パターンと SIMD 計算も使用されます。

単純な C を書き、演算数を理論上の最小値まで減らしたとしても、完全な行列乗算には勝てないでしょう。これが要約すると、操作に最もよく一致する数値プリミティブを見つけることです。

そうは言っても、DGEMM (行列乗算) よりも少し近い BLAS 操作があります。これは DSYRK、ランク k 更新と呼ばれ、正確に使用できますA'*A。このためにずっと前に書いた MEX 関数はこちらです。私は長い間それをいじっていませんでしたが、最初に書いたときはうまくいき、実際にはストレートよりも速く実行されましたA'*A.

/* xtrx.c: calculates x'*x taking advantage of the symmetry.
Peter Boettcher <email removed>
Last modified: <Thu Jan 23 13:53:02 2003> */

#include "mex.h"

const double one = 1;
const double zero = 0;

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  double *x, *z;
  int i, j, mrows, ncols;

  if(nrhs!=1) mexErrMsgTxt("One input required.");

  x = mxGetPr(prhs[0]);
  mrows = mxGetM(prhs[0]);
  ncols = mxGetN(prhs[0]);

  plhs[0] = mxCreateDoubleMatrix(ncols,ncols, mxREAL);
  z = mxGetPr(plhs[0]);

  /* Call the FORTRAN BLAS routine for rank k update */
  dsyrk_("U", "T", &ncols, &mrows, &one, x, &mrows, &zero, z, &ncols);

  /* Result is in the upper triangle.  Copy it down the lower part */
  for(i=0; i<ncols; i++)
      for(j=i+1; j<ncols; j++)
          z[i*ncols + j] = z[j*ncols + i];
}
于 2013-10-01T20:22:43.660 に答える
5

これは (1:N).'*(1:N) よりも 3 倍高速です (結果が許容範囲内であれば) (数値がの代わりにint32使用するのに十分小さい場合はさらに高速です)。int16int32

N = 1000;
aux = int32(1:N);
result = bsxfun(@times,aux.',aux);

ベンチマーク:

>> N = 1000; aux = int32(1:N); tic, for count = 1:1e2, bsxfun(@times,aux.',aux); end, toc
Elapsed time is 0.734992 seconds.

>> N = 1000; aux = 1:N; tic, for count = 1:1e2, aux.'*aux; end, toc
Elapsed time is 2.281784 seconds.

aux.'*auxには使用できませんのでご注意くださいaux = int32(1:N)

@DanielE.Shub で指摘されているように、結果がdouble行列として必要な場合は、最終的なキャストを行う必要があり、その場合、ゲインは非常に小さくなります。

>> N = 1000; aux = int32(1:N); tic, for count = 1:1e2, double(bsxfun(@times,aux.',aux)); end, toc
Elapsed time is 2.173059 seconds.
于 2013-10-03T16:00:15.927 に答える
3

入力の特別な順序構造のため、N=4 の場合を考えてください。

(1:4)'*(1:4) = [1 2 3 4
                2 4 6 8
                3 6 9 12
                4 8 12 16]

1行目はちょうど(1:N)であり、2行目(j = 2)行から、この行の値は前の行(j = 1)に(1:N)を加えたものであることがわかります。ですから 1. 掛け算はあまりしません。代わりに、N*N 回の加算によって生成できます。2. 出力は対称であるため、出力行列の半分だけを計算する必要があります。したがって、合計計算は (N-1)+(N-2)+...+1 = N^2 / 2 回の加算です。

于 2013-10-01T19:31:46.847 に答える