問題タブ [jax]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票する
2 に答える
227 参照

python - Mypy が 2 つの Jax 配列を追加すると numpy 配列が返されると考えるのはなぜですか?

次のファイルを検討してください。

実行mypy mypytest.pyすると、次のエラーが返されます。

何らかの理由で、2 つの s を追加するとboolsjax.numpy.ndarrayの NumPy 配列が返されると考えられています。私は何か間違ったことをしていますか?それとも、これは MyPy または Jax の型注釈のバグですか?

0 投票する
1 に答える
370 参照

python - 微分を伴う Python のエラー メッセージ

私はジェネリック コール オプションのモンテカルロ アプローチを使用してこれらの導関数を計算しています。私は、この組み合わせ導関数 (S とシグマの両方に関して) に興味があります。アルゴリズムの微分でこれを行うと、ページの最後に表示されるエラーが表示されます。可能な解決策は何ですか?コードに関して何かを説明するために、以下のコードで「X」を計算するために使用される式を添付します。

ここに画像の説明を入力

これはエラーメッセージです:

以下のスタック トレースは、JAX 内部フレームを除外しています。上記は、発生した元の例外であり、変更されていません。


上記の例外は、次の例外の直接の原因でした。

0 投票する
1 に答える
238 参照

python - JAX/JIT と Std Numpy のパフォーマンス: どこが間違っているのでしょうか?

ここでは、いくつかの関数を受け入れて境界のセットを統合するために作成した Simpson 統合コードを使用した簡単な演習を行います。

私は CPU デバイスを使用しており、Int(f_i, a_j,b_j) i:0-199 および j:0-99 に対応する 200 x 100 の数値配列を取得します。

%timeit simps(funcN,a,b, 512)

ループあたり 1.13 秒 ± 27.4 ミリ秒 (7 回の実行の平均値 ± 標準偏差、各ループ 1 回)

次の JAX/JIT バージョンを検討してください。

2 つのコード (純粋な Numpy と JAX/JIT) が同じ結果をもたらすことを確認しました。最大相対誤差は 8.10^-16 のオーダーです。

今、私はループごとに次のタイミング933ミリ秒±51.4ミリ秒を得ました(7回の実行の平均±標準偏差、それぞれ1ループ)

これは純粋な Numpy に非常に近いものです。たまたま非常に効率的な純粋な Numpy コードを作成したことがありますか? または、間違った方法で JAX/JIT をコーディングしましたか?

(注: Google コラボ K80 GPU を使用すると、JAX/JIT のタイミングがループごとに 7.19 ミリ秒に低下し、純粋な Numpy が 1 秒/ループのレベルに維持されます)