1

次のファイルを検討してください。

import jax.numpy as jnp

def test(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    return a + b

実行mypy mypytest.pyすると、次のエラーが返されます。

mypytest.py:4: error: Incompatible return value type (got "numpy.ndarray[Any, dtype[bool_]]", expected "jax._src.numpy.lax_numpy.ndarray")

何らかの理由で、2 つの s を追加するとboolsjax.numpy.ndarrayの NumPy 配列が返されると考えられています。私は何か間違ったことをしていますか?それとも、これは MyPy または Jax の型注釈のバグですか?

4

2 に答える 2