次のファイルを検討してください。
import jax.numpy as jnp
def test(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
return a + b
実行mypy mypytest.py
すると、次のエラーが返されます。
mypytest.py:4: error: Incompatible return value type (got "numpy.ndarray[Any, dtype[bool_]]", expected "jax._src.numpy.lax_numpy.ndarray")
何らかの理由で、2 つの s を追加するとboolsjax.numpy.ndarray
の NumPy 配列が返されると考えられています。私は何か間違ったことをしていますか?それとも、これは MyPy または Jax の型注釈のバグですか?