私は関数を持ってcompute(x)
いx
ますjnp.ndarray
。vmap
今、私はそれを配列のバッチを取る関数に変換しx[i]
、それjit
を高速化するために使用したいと考えています。compute(x)
次のようなものです:
def compute(x):
# ... some code
y = very_expensive_function(x)
return y
ただし、各配列x[i]
の長さは異なります。N
この問題は、配列に末尾のゼロをパディングして、すべて同じ長さになりvmap(compute)
、 shape のバッチに適用できるようにすることで簡単に回避できます(batch_size, N)
。
ただし、そうすると、very_expensive_function()
各配列の末尾のゼロに対しても呼び出されることになりますx[i]
。とに干渉することなく、 のスライスでのみ呼び出されるものcompute()
を変更する方法はありますか?very_expensive_function()
x
vmap
jit