バッチ行列乗算を使用して手順を書いていますが、一般的な設定ではない可能性があります。次の入力を検討しています。
# Let's say I have a list of points in R^3, from 3 distinct objects
# (so my data batch has 3 data entry)
# X: (B1+B2+B3) * 3
X = torch.tensor([[1,1,1],[1,1,1],
[2,2,2],[2,2,2],[2,2,2],
[3,3,3],])
# To indicate which object the points are corresponding to,
# I have a list of indices (say, starting from 0):
# idx: (B1+B2+B3)
idx = torch.tensor([0,0,1,1,1,2])
# For each point from the same object, I want to multiply it to a 3x3 matrix, A_i.
# As I have 3 objects here, I have A_0, A_1, A_2.
# A: 3 x 3 x 3
A = torch.tensor([[[1,1,1],[1,1,1],[1,1,1]],
[[2,2,2],[2,2,2],[2,2,2]],
[[3,3,3],[3,3,3],[3,3,3]]])
そして、望ましい出力は次のとおりです。
out = X.unsqueeze(1).bmm(A[idx])
out = out.squeeze(1) # just to remove excessive dimension
# out = torch.tensor([[[1,1,1]],[[1,1,1]], # obj0 mult with A_0
[[2,2,2]],[[2,2,2]],[[2,2,2]], # obj1 mult with A_1
[[3,3,3]],]) # obj2 mult with A_2
実際、pytorch では非常に便利で、たった 1 行です!
ここで、この手順を改善したいと思います。A[idx]を使用してポイントごとに 1 つの行列 A_i を複製しているので、ここで torch.bmm() 関数を使用できることに注意してください (1 ポイント <-> 1 行列)。Afaik、それはA[idx] の中間表現のためにメモリを割り当てる必要があります。一般に、データ バッチに BN オブジェクトがある場合、A[idx] のサイズ = (B1+...+BN)*3*3 となり、非常に大きくなる可能性があります。
したがって、行列 A_i の複製を回避できるかどうか疑問に思っています。
Batch Mat に関して以前に寄せられたほとんどの質問を見つけました。マルチ。固定バッチ サイズのみを想定します。ここで私のものと同じ質問がされ、テンソルフローでの解決策が提供されました。ただし、ソリューションは tf.tile() を使用して実装されており、これも行列を複製しています。
要約すると、私の質問は、次のことを達成しながら、バッチ行列乗算についてです。
- dynamic batch size
- input shape: (B1+...+BN) x 3
- index shape: (B1+...+BN)
- memory efficiency
- probably w/out massive replication of matrix
ここでは pytorch を使用していますが、他の実装も受け入れます。入力 (乗算する行列、A など) を他の構造体で表現することも、結果としてメモリ効率が向上する場合は受け入れます。