問題タブ [jax]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票する
1 に答える
690 参照

python - ゼロからのバイナリ クロス エントロピーの実装 - ニューラル ネットワークのトレーニングで一貫性のない結果が得られる

JAXライブラリとその小さなニューラルネットワークサブモジュール「Stax」を使用して、ニューラルネットワークを実装およびトレーニングしようとしています。このライブラリにはバイナリ クロス エントロピーの実装が付属していないため、独自に作成しました。

単純なニューラル ネットワークを実装して MNIST でトレーニングしたところ、得られた結果の一部に疑いを持ち始めました。そこで、Keras で同じセットアップを実装したところ、すぐにまったく異なる結果が得られました。同じデータで同じ方法でトレーニングされた同じモデルは、JAX では約 50% でしたが、Keras では 90% のトレーニング精度が得られました。最終的に、数値的に不安定であると思われるクロスエントロピーの単純な実装に問題の一部を突き止めました。この投稿と見つけたこのコードに従って、次の新しいバージョンを作成しました。

これは少しうまくいきます。現在、私の JAX 実装は最大 80% のトレーニング精度を達成していますが、それでも Keras の 90% よりもはるかに低いです。私が知りたいのは、何が起こっているのですか?2 つの実装が同じように動作しないのはなぜですか?

以下では、2 つの実装を 1 つのスクリプトに要約しました。このスクリプトでは、JAX と Keras で同じモデルを実装しています。両方を同じ weightsで初期化し、各モデルの同じデータである MNIST からの 1000 データポイントで 10 ステップのフルバッチ勾配降下法を使用してトレーニングします。JAX は 80% のトレーニング精度で終了しますが、Keras は 90% で終了します。具体的には、次の出力が得られます。

実際、条件を少し変えると (異なるランダムな初期重みまたは異なるトレーニング セットを使用して)、50% の JAX 精度と 90% の Keras 精度が得られることがあります。

最後に重みを交換して、トレーニングから得られた重みが実際に問題であり、ネットワーク予測の実際の計算や精度の計算方法とは関係がないことを確認します。

コード:

057 行目の PRNG シードを、異なる初期重みを使用して実験を実行する以外の値に変更してみてください。

0 投票する
1 に答える
247 参照

python - 長さが異なる JAX バッチ処理

私は関数を持ってcompute(x)xますjnp.ndarrayvmap今、私はそれを配列のバッチを取る関数に変換しx[i]、それjitを高速化するために使用したいと考えています。compute(x)次のようなものです:

ただし、各配列x[i]の長さは異なります。Nこの問題は、配列に末尾のゼロをパディングして、すべて同じ長さになりvmap(compute)、 shape のバッチに適用できるようにすることで簡単に回避できます(batch_size, N)

ただし、そうすると、very_expensive_function()各配列の末尾のゼロに対しても呼び出されることになりますx[i]。とに干渉することなく、 のスライスでのみ呼び出されるものcompute()を変更する方法はありますか?very_expensive_function()xvmapjit