問題タブ [numerical-stability]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票する
1 に答える
690 参照

python - ゼロからのバイナリ クロス エントロピーの実装 - ニューラル ネットワークのトレーニングで一貫性のない結果が得られる

JAXライブラリとその小さなニューラルネットワークサブモジュール「Stax」を使用して、ニューラルネットワークを実装およびトレーニングしようとしています。このライブラリにはバイナリ クロス エントロピーの実装が付属していないため、独自に作成しました。

単純なニューラル ネットワークを実装して MNIST でトレーニングしたところ、得られた結果の一部に疑いを持ち始めました。そこで、Keras で同じセットアップを実装したところ、すぐにまったく異なる結果が得られました。同じデータで同じ方法でトレーニングされた同じモデルは、JAX では約 50% でしたが、Keras では 90% のトレーニング精度が得られました。最終的に、数値的に不安定であると思われるクロスエントロピーの単純な実装に問題の一部を突き止めました。この投稿と見つけたこのコードに従って、次の新しいバージョンを作成しました。

これは少しうまくいきます。現在、私の JAX 実装は最大 80% のトレーニング精度を達成していますが、それでも Keras の 90% よりもはるかに低いです。私が知りたいのは、何が起こっているのですか?2 つの実装が同じように動作しないのはなぜですか?

以下では、2 つの実装を 1 つのスクリプトに要約しました。このスクリプトでは、JAX と Keras で同じモデルを実装しています。両方を同じ weightsで初期化し、各モデルの同じデータである MNIST からの 1000 データポイントで 10 ステップのフルバッチ勾配降下法を使用してトレーニングします。JAX は 80% のトレーニング精度で終了しますが、Keras は 90% で終了します。具体的には、次の出力が得られます。

実際、条件を少し変えると (異なるランダムな初期重みまたは異なるトレーニング セットを使用して)、50% の JAX 精度と 90% の Keras 精度が得られることがあります。

最後に重みを交換して、トレーニングから得られた重みが実際に問題であり、ネットワーク予測の実際の計算や精度の計算方法とは関係がないことを確認します。

コード:

057 行目の PRNG シードを、異なる初期重みを使用して実験を実行する以外の値に変更してみてください。

0 投票する
1 に答える
296 参照

python - ソフトマックスのテンソルフローの問題

を使用して確率を生成nanまたは計算している Tensorflow マルチクラス分類器があります。次のスニペットを参照してください ( 6 つのクラスがあり、出力がワンホット エンコードされているため、形状はあります)。は 1024 です。inftf.nn.softmaxlogitsbatch_size x 6batch_size

分類子は、nanまたはinfで見つかった最後のステートメントで失敗しますprobabilitieslogitsそうでなければ、最初のステートメントは失敗していたでしょう。

私が読んだことからtf.nn.softmax、ロジットで非常に大きな値と非常に小さな値を処理できます。対話モードでこれを確認しました。

次に、値をクリップしてみましたがlogits、すべてが機能するようになりました。以下の変更されたスニペットを参照してください。

2 番目のステートメントでは、値logitsを -15 と 15 にクリッピングしています。これにより、softmax 計算でnan/が何らかの形で防止されます。infそのため、当面の問題を修正することができました。

しかし、なぜこのクリッピングが機能しているのか、まだわかりませんか? (-20 と 20 の間のクリッピングは機能せず、モデルはnanまたはinfで失敗することに注意してくださいprobabilities)。

なぜこれが当てはまるのか、誰かが理解するのを手伝ってくれますか?

64 ビット インスタンスで実行されている tensorflow 1.15.0 を使用しています。