2

次のコードは機能しません。

def get_unique(arr):
    return jnp.unique(arr)

get_unique = jit(get_unique)
get_unique(jnp.ones((10,)))

エラーメッセージは、次の使用について説明していますjnp.unique

FilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=0/1)>
The error arose in jnp.unique()

シャープ ビットに関するドキュメントでは、内部配列の形状が引数の値に依存する場合、jit は機能しないと説明されています。これはまさにここに当てはまります。

ドキュメントによると、潜在的な回避策は静的パラメーターを指定することです。しかし、これは私の場合には当てはまりません。パラメータは、ほぼすべての関数呼び出しで変更されます。コードを、このような計算を実行する前処理ステップjnp.uniqueと、jit できる計算ステップに分割しました。

それでも聞きたいのですが、私が気付いていない回避策はありますか?

4

1 に答える 1

1

いいえ、あなたが言及した理由により、現在、jnp.unique静的でない値で使用する方法はありません。

同様のケースで、JAX は、出力の静的サイズを指定するために使用できる追加のパラメーター (たとえば、 のsizeパラメーターjax.numpy.nonzero) を追加することがありますが、現時点ではそのようなものは実装されていませんjnp.unique。それが必要な場合は、機能リクエストを提出する価値があります。

于 2021-05-30T03:49:28.003 に答える