Find centralized, trusted content and collaborate around the technologies you use most.
Teams
Q&A for work
Connect and share knowledge within a single location that is structured and easy to search.
私は JAX を使用しており、次のような操作を実行したい
@jax.jit def fun(x, index): x[:index] = other_fun(x[:index]) return x
これは では実行できませんjit。またはでこれを行う方法はありますjax.opsかjax.lax?使おうと思ったのですが、同じ問題を繰り返さずjax.ops.index_update(x, idx, y)に計算する方法が見つかりません。y
jit
jax.ops
jax.lax
jax.ops.index_update(x, idx, y)
y