60

行列の乗算を行う最速の方法を見つけようとしていて、3 つの異なる方法を試しました。

  • 純粋な python 実装: ここで驚きはありません。
  • を使用したナンピーな実装numpy.dot(a, b)
  • ctypesPython のモジュールを使用した C とのインターフェイス。

これは、共有ライブラリに変換される C コードです。

#include <stdio.h>
#include <stdlib.h>

void matmult(float* a, float* b, float* c, int n) {
    int i = 0;
    int j = 0;
    int k = 0;

    /*float* c = malloc(nay * sizeof(float));*/

    for (i = 0; i < n; i++) {
        for (j = 0; j < n; j++) {
            int sub = 0;
            for (k = 0; k < n; k++) {
                sub = sub + a[i * n + k] * b[k * n + j];
            }
            c[i * n + j] = sub;
        }
    }
    return ;
}

そして、それを呼び出す Python コード:

def C_mat_mult(a, b):
    libmatmult = ctypes.CDLL("./matmult.so")

    dima = len(a) * len(a)
    dimb = len(b) * len(b)

    array_a = ctypes.c_float * dima
    array_b = ctypes.c_float * dimb
    array_c = ctypes.c_float * dima

    suma = array_a()
    sumb = array_b()
    sumc = array_c()

    inda = 0
    for i in range(0, len(a)):
        for j in range(0, len(a[i])):
            suma[inda] = a[i][j]
            inda = inda + 1
        indb = 0
    for i in range(0, len(b)):
        for j in range(0, len(b[i])):
            sumb[indb] = b[i][j]
            indb = indb + 1

    libmatmult.matmult(ctypes.byref(suma), ctypes.byref(sumb), ctypes.byref(sumc), 2);

    res = numpy.zeros([len(a), len(a)])
    indc = 0
    for i in range(0, len(sumc)):
        res[indc][i % len(a)] = sumc[i]
        if i % len(a) == len(a) - 1:
            indc = indc + 1

    return res

C を使用したバージョンの方が高速だったと思いますが、負けていたでしょう。以下は私のベンチマークで、私のやり方が間違っていたのか、numpy馬鹿げた速さだったのかを示しているようです:

基準

numpyバージョンがバージョンよりも速い理由を理解したいのctypesですが、純粋なPythonの実装については話していません。それは明らかなことだからです。

4

6 に答える 6

37

NumPyは、行列の乗算に高度に最適化され、注意深く調整されたBLASメソッドを使用します(ATLASも参照)。この場合の特定の関数はGEMM(一般的な行列乗算用)です。dgemm.f(Netlibにあります)を検索してオリジナルを検索できます。

ちなみに、最適化はコンパイラの最適化を超えています。上記で、PhilipはCoppersmith–Winogradについて言及しました。私の記憶が正しければ、これはATLASでの行列乗算のほとんどの場合に使用されるアルゴリズムです(ただし、コメント提供者は、Strassenのアルゴリズムである可能性があると述べています)。

言い換えれば、あなたのmatmultアルゴリズムは簡単な実装です。同じことをするより速い方法があります。

于 2012-08-20T02:48:41.520 に答える
28

私はNumpyにあまり精通していませんが、ソースはGithubにあります。ドット積の一部はhttps://github.com/numpy/numpy/blob/master/numpy/core/src/multiarray/arraytypes.c.srcに実装されており、それぞれに特定のC実装に変換されると思います。データ・タイプ。例えば:

/**begin repeat
 *
 * #name = BYTE, UBYTE, SHORT, USHORT, INT, UINT,
 * LONG, ULONG, LONGLONG, ULONGLONG,
 * FLOAT, DOUBLE, LONGDOUBLE,
 * DATETIME, TIMEDELTA#
 * #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint,
 * npy_long, npy_ulong, npy_longlong, npy_ulonglong,
 * npy_float, npy_double, npy_longdouble,
 * npy_datetime, npy_timedelta#
 * #out = npy_long, npy_ulong, npy_long, npy_ulong, npy_long, npy_ulong,
 * npy_long, npy_ulong, npy_longlong, npy_ulonglong,
 * npy_float, npy_double, npy_longdouble,
 * npy_datetime, npy_timedelta#
 */
static void
@name@_dot(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp n,
           void *NPY_UNUSED(ignore))
{
    @out@ tmp = (@out@)0;
    npy_intp i;

    for (i = 0; i < n; i++, ip1 += is1, ip2 += is2) {
        tmp += (@out@)(*((@type@ *)ip1)) *
               (@out@)(*((@type@ *)ip2));
    }
    *((@type@ *)op) = (@type@) tmp;
}
/**end repeat**/

これは、1次元の内積、つまりベクトルを計算しているように見えます。数分間のGithubブラウジングでは、マトリックスのソースを見つけることができませんでしたがFLOAT_dot、結果マトリックスの要素ごとに1つの呼び出しを使用する可能性があります。つまり、この関数のループは最も内側のループに対応します。

それらの違いの1つは、「ストライド」(入力内の連続する要素間の違い)が、関数を呼び出す前に1回明示的に計算されることです。あなたの場合、ストライドはなく、各入力のオフセットは毎回計算されますa[i * n + k]。Numpyストライドに似たものに最適化する優れたコンパイラーを期待していましたが、ステップが一定であること(または最適化されていないこと)を証明できない可能性があります。

Numpyは、この関数を呼び出す高レベルのコードのキャッシュ効果を使って何か賢いことをしている可能性もあります。一般的なトリックは、各行が連続しているか、各列が連続しているかを考え、最初に各連続部分を反復処理することです。完全に最適化するのは難しいようです。各ドット積について、1つの入力行列を行でトラバースし、もう1つを列でトラバースする必要があります(たまたま異なるメジャー順序で格納されている場合を除く)。しかし、少なくとも結果要素に対してはそれを行うことができます。

Numpyには、さまざまな基本的な実装から「ドット」を含む特定の操作の実装を選択するためのコードも含まれています。たとえば、BLASライブラリを使用できます。上記の議論から、CBLASが使用されているように聞こえます。これはFortranからCに変換されました。テストで使用された実装は、http: //www.netlib.org/clapack/cblas/sdot.cにあるものだと思います。

このプログラムは、別のマシンが読み取るためにマシンによって作成されたことに注意してください。しかし、下部には、展開されたループを使用して一度に5つの要素を処理していることがわかります。

for (i = mp1; i <= *n; i += 5) {
stemp = stemp + SX(i) * SY(i) + SX(i + 1) * SY(i + 1) + SX(i + 2) * 
    SY(i + 2) + SX(i + 3) * SY(i + 3) + SX(i + 4) * SY(i + 4);
}

この展開要因は、いくつかのプロファイリング後に選択された可能性があります。ただし、理論上の利点の1つは、各分岐点間でより多くの算術演算が実行され、コンパイラとCPUが、可能な限り多くの命令パイプラインを取得するためにそれらを最適にスケジュールする方法についてより多くの選択肢があることです。

于 2012-05-04T05:01:05.363 に答える
10

特定の機能を実装するために使用される言語は、それ自体ではパフォーマンスの悪い尺度です。多くの場合、より適切なアルゴリズムを使用することが決定要因になります。

あなたの場合、O(n ^ 3)である学校で教えられているように、行列の乗算に単純なアプローチを使用しています。ただし、特定の種類の行列、たとえば正方行列、予備行列などについては、はるかにうまく処理できます。

高速な行列乗算の出発点として、Coppersmith–Winograd アルゴリズム(O(n^2.3737) の正方行列乗算) を参照してください。また、さらに高速な方法へのいくつかのポインターをリストした「参考資料」セクションも参照してください。


驚くべきパフォーマンス向上のより素朴な例については、fast を書き、strlen()それを glibc 実装と比較してみてください。うまくいかない場合は、glibc のstrlen()ソースを読んでください。かなり良いコメントがあります。

于 2012-05-04T06:30:48.547 に答える
5

NumPyを書いた人々は、明らかに彼らが何をしているのかを知っています。

行列の乗算を最適化する方法はたくさんあります。たとえば、マトリックスをトラバースする順序は、パフォーマンスに影響するメモリアクセスパターンに影響します。
SSEの適切な使用は、NumPyがおそらく採用している最適化のもう1つの方法です。
NumPyの開発者が知っている方法と私が知らない方法は他にもあるかもしれません。

ところで、最適化を使用してCコードをコンパイルしましたか?

Cに対して次の最適化を試すことができます。これは並行して機能し、NumPyは同じ線に沿って何かを行うと思います。
注:偶数サイズでのみ機能します。余分な作業を行うことで、この制限を取り除き、パフォーマンスの向上を維持できます。

for (i = 0; i < n; i++) {
        for (j = 0; j < n; j+=2) {
            int sub1 = 0, sub2 = 0;
            for (k = 0; k < n; k++) {
                sub1 = sub1 + a[i * n + k] * b[k * n + j];
                sub1 = sub1 + a[i * n + k] * b[k * n + j + 1];
            }
            c[i * n + j]     = sub;
            c[i * n + j + 1] = sub;
        }
    }
}
于 2012-05-04T04:26:52.903 に答える
5

Numpy は高度に最適化されたコードでもあります。その一部についてのエッセイが本Beautiful Codeにあります。

ctypes は、C から Python への動的な変換を経なければならず、オーバーヘッドがいくらか追加されます。Numpy では、ほとんどの行列演算は完全に内部で行われます。

于 2012-05-04T04:33:13.567 に答える
2

数値コードでの Fortran の速度の優位性について与えられた最も一般的な理由は、言語がエイリアシングを検出しやすくすることです。コンパイラは、乗算される行列が同じメモリを共有していないことを認識できるため、キャッシングの改善に役立ちます (いいえ結果がすぐに「共有」メモリに書き戻されることを確認する必要があります)。これが、C99 がrestrictを導入した理由です。

ただし、この場合、numpy コードも、C コードでは使用できない特別な命令を使用しているのだろうか(違いが特に大きいように見えるため)。

于 2012-05-04T04:31:36.193 に答える