問題タブ [jax]
For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.
python - Mypy が 2 つの Jax 配列を追加すると numpy 配列が返されると考えるのはなぜですか?
次のファイルを検討してください。
実行mypy mypytest.py
すると、次のエラーが返されます。
何らかの理由で、2 つの s を追加するとboolsjax.numpy.ndarray
の NumPy 配列が返されると考えられています。私は何か間違ったことをしていますか?それとも、これは MyPy または Jax の型注釈のバグですか?
python - JAX/JIT と Std Numpy のパフォーマンス: どこが間違っているのでしょうか?
ここでは、いくつかの関数を受け入れて境界のセットを統合するために作成した Simpson 統合コードを使用した簡単な演習を行います。
私は CPU デバイスを使用しており、Int(f_i, a_j,b_j) i:0-199 および j:0-99 に対応する 200 x 100 の数値配列を取得します。
%timeit simps(funcN,a,b, 512)
ループあたり 1.13 秒 ± 27.4 ミリ秒 (7 回の実行の平均値 ± 標準偏差、各ループ 1 回)
次の JAX/JIT バージョンを検討してください。
2 つのコード (純粋な Numpy と JAX/JIT) が同じ結果をもたらすことを確認しました。最大相対誤差は 8.10^-16 のオーダーです。
今、私はループごとに次のタイミング933ミリ秒±51.4ミリ秒を得ました(7回の実行の平均±標準偏差、それぞれ1ループ)
これは純粋な Numpy に非常に近いものです。たまたま非常に効率的な純粋な Numpy コードを作成したことがありますか? または、間違った方法で JAX/JIT をコーディングしましたか?
(注: Google コラボ K80 GPU を使用すると、JAX/JIT のタイミングがループごとに 7.19 ミリ秒に低下し、純粋な Numpy が 1 秒/ループのレベルに維持されます)