1

次の numpy コードはまったく問題ありません。

arr = np.arange(50)
print(arr.shape) # (50,)

indices = np.zeros((30,), dtype=int)
print(indices.shape) # (30,)

arr[indices]

jax に移行した後も機能します。

import jax.numpy as jnp

arr = jnp.arange(50)
print(arr.shape) # (50,)

indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)

arr[indices]

それでは、numpy と jax の組み合わせを試してみましょう。

arr = np.arange(50)
print(arr.shape) # (50,)

indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)

arr[indices]

これにより、次のエラーが発生します。

IndexError: too many indices for array: array is 1-dimensional, but 30 were indexed

jax 配列を使用した numpy 配列へのインデックス付けがサポートされていない場合、それは問題ありません。しかし、エラーメッセージは間違っているようです。そして、事態はさらに混乱します。形状を少し変更すると、コードは正常に機能します。次のサンプルでは、​​(30,) から (40,) までのインデックスの形状のみを編集しました。エラーメッセージはもうありません:

arr = np.arange(50)
print(arr.shape) # (50,)

indices = jnp.zeros((40,), dtype=int)
print(indices.shape) # (40,)

arr[indices]

CPUでjaxバージョン「0.2.12」を実行しています。ここで何が起きてるの?

4

1 に答える 1