11

私はPytorch seq2seq チュートリアルに従っており、そのtorch.bmmメソッドは以下のように使用されます:

attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                         encoder_outputs.unsqueeze(0))

注意の重みとエンコーダーの出力を乗算する必要がある理由を理解しています。

bmm私がよく理解していないのは、ここでメソッド が必要な理由です。torch.bmm文書によると

バッチ 1 とバッチ 2 に格納されている行列のバッチ行列行列積を実行します。

batch1 と batch2 は、それぞれが同じ数の行列を含む 3 次元テンソルでなければなりません。

batch1 が (b×n×m) テンソル、batch2 が (b×m×p) テンソルの場合、out は (b×n×p) テンソルになります。

ここに画像の説明を入力

4

3 に答える 3