2

私は関数を持ってcompute(x)xますjnp.ndarrayvmap今、私はそれを配列のバッチを取る関数に変換しx[i]、それjitを高速化するために使用したいと考えています。compute(x)次のようなものです:

def compute(x):
    # ... some code
    y = very_expensive_function(x)
    return y

ただし、各配列x[i]の長さは異なります。Nこの問題は、配列に末尾のゼロをパディングして、すべて同じ長さになりvmap(compute)、 shape のバッチに適用できるようにすることで簡単に回避できます(batch_size, N)

ただし、そうすると、very_expensive_function()各配列の末尾のゼロに対しても呼び出されることになりますx[i]。とに干渉することなく、 のスライスでのみ呼び出されるものcompute()を変更する方法はありますか?very_expensive_function()xvmapjit

4

1 に答える 1