私はPytorch seq2seq チュートリアルに従っており、そのtorch.bmm
メソッドは以下のように使用されます:
attn_applied = torch.bmm(attn_weights.unsqueeze(0),
encoder_outputs.unsqueeze(0))
注意の重みとエンコーダーの出力を乗算する必要がある理由を理解しています。
bmm
私がよく理解していないのは、ここでメソッド
が必要な理由です。torch.bmm
文書によると
バッチ 1 とバッチ 2 に格納されている行列のバッチ行列行列積を実行します。
batch1 と batch2 は、それぞれが同じ数の行列を含む 3 次元テンソルでなければなりません。
batch1 が (b×n×m) テンソル、batch2 が (b×m×p) テンソルの場合、out は (b×n×p) テンソルになります。