次のコードは機能しません。
def get_unique(arr):
return jnp.unique(arr)
get_unique = jit(get_unique)
get_unique(jnp.ones((10,)))
エラーメッセージは、次の使用について説明していますjnp.unique
。
FilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=0/1)>
The error arose in jnp.unique()
シャープ ビットに関するドキュメントでは、内部配列の形状が引数の値に依存する場合、jit は機能しないと説明されています。これはまさにここに当てはまります。
ドキュメントによると、潜在的な回避策は静的パラメーターを指定することです。しかし、これは私の場合には当てはまりません。パラメータは、ほぼすべての関数呼び出しで変更されます。コードを、このような計算を実行する前処理ステップjnp.unique
と、jit できる計算ステップに分割しました。
それでも聞きたいのですが、私が気付いていない回避策はありますか?