3

私は JAX を使用しており、次のような操作を実行したい

@jax.jit
def fun(x, index):
    x[:index] = other_fun(x[:index])
    return x

これは では実行できませんjit。またはでこれを行う方法はありますjax.opsjax.lax?使おうと思ったのですが、同じ問題を繰り返さずjax.ops.index_update(x, idx, y)に計算する方法が見つかりません。y

4

2 に答える 2