5

デフォルトの行列乗算は次のように計算されます。

c[i,j] = sum(a[i,k] * b[k,j])

ドット積の代わりにカスタム式を使用して取得しようとしています

c[i,j] = sum(a[i,k] == b[k,j])

numpyでこれを行う効率的な方法はありますか?

4

1 に答える 1

7

ブロードキャストを使用できます:

c = sum(a[...,np.newaxis]*b[np.newaxis,...],axis=1)  # == np.dot(a,b)

c = sum(a[...,np.newaxis]==b[np.newaxis,...],axis=1)

を含めたのはnewaxisbその配列がどのように展開されるかを明確にするためです。配列に次元を追加する方法は他にもありますが (形状変更、繰り返しなど)、効果は同じです。要素ごとに要素の乗算 (または ==) を行うために同じ形状に展開し、正しい軸で合計しますab

于 2013-10-09T17:10:19.500 に答える