1

バッチ行列乗算を使用して手順を書いていますが、一般的な設定ではない可能性があります。次の入力を検討しています。

# Let's say I have a list of points in R^3, from 3 distinct objects
# (so my data batch has 3 data entry)
# X: (B1+B2+B3) * 3 
X = torch.tensor([[1,1,1],[1,1,1],
                  [2,2,2],[2,2,2],[2,2,2],
                  [3,3,3],])
# To indicate which object the points are corresponding to,
# I have a list of indices (say, starting from 0):
# idx: (B1+B2+B3)
idx = torch.tensor([0,0,1,1,1,2])
# For each point from the same object, I want to multiply it to a 3x3 matrix, A_i.
# As I have 3 objects here, I have A_0, A_1, A_2.
# A: 3 x 3 x 3
A = torch.tensor([[[1,1,1],[1,1,1],[1,1,1]],
                  [[2,2,2],[2,2,2],[2,2,2]],
                  [[3,3,3],[3,3,3],[3,3,3]]])

そして、望ましい出力は次のとおりです。

out = X.unsqueeze(1).bmm(A[idx])
out = out.squeeze(1)                # just to remove excessive dimension
# out = torch.tensor([[[1,1,1]],[[1,1,1]],            # obj0 mult with A_0
                      [[2,2,2]],[[2,2,2]],[[2,2,2]],  # obj1 mult with A_1
                      [[3,3,3]],])                    # obj2 mult with A_2

実際、pytorch では非常に便利で、たった 1 行です!

ここで、この手順を改善したいと思います。A[idx]を使用してポイントごとに 1 つの行列 A_i を複製しているので、ここで torch.bmm() 関数を使用できることに注意してください (1 ポイント <-> 1 行列)。Afaik、それはA[idx] の中間表現のためにメモリを割り当てる必要があります。一般に、データ バッチに BN オブジェクトがある場合、A[idx] のサイズ = (B1+...+BN)*3*3 となり、非常に大きくなる可能性があります。

したがって、行列 A_i の複製を回避できるかどうか疑問に思っています。

Batch Mat に関して以前に寄せられたほとんどの質問を見つけました。マルチ。固定バッチ サイズのみを想定します。ここで私のものと同じ質問がされ、テンソルフローでの解決策が提供されました。ただし、ソリューションは tf.tile() を使用して実装されており、これも行列を複製しています。


要約すると、私の質問は、次のことを達成しながら、バッチ行列乗算についてです。

- dynamic batch size
    - input shape: (B1+...+BN) x 3
    - index shape: (B1+...+BN)
- memory efficiency
    - probably w/out massive replication of matrix

ここでは pytorch を使用していますが、他の実装も受け入れます。入力 (乗算する行列、A など) を他の構造体で表現することも、結果としてメモリ効率が向上する場合は受け入れます。

4

0 に答える 0