次の numpy コードはまったく問題ありません。
arr = np.arange(50)
print(arr.shape) # (50,)
indices = np.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
jax に移行した後も機能します。
import jax.numpy as jnp
arr = jnp.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
それでは、numpy と jax の組み合わせを試してみましょう。
arr = np.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
これにより、次のエラーが発生します。
IndexError: too many indices for array: array is 1-dimensional, but 30 were indexed
jax 配列を使用した numpy 配列へのインデックス付けがサポートされていない場合、それは問題ありません。しかし、エラーメッセージは間違っているようです。そして、事態はさらに混乱します。形状を少し変更すると、コードは正常に機能します。次のサンプルでは、(30,) から (40,) までのインデックスの形状のみを編集しました。エラーメッセージはもうありません:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((40,), dtype=int)
print(indices.shape) # (40,)
arr[indices]
CPUでjaxバージョン「0.2.12」を実行しています。ここで何が起きてるの?