3

コンテキスト: ドキュメントのセットがあり、それぞれに 2 つの確率値が関連付けられています。クラス A に属する確率またはクラス B に属する確率です。クラスは相互に排他的であり、確率の合計は 1 になります。したがって、たとえばドキュメント D には、グラウンド トゥルースとして関連付けられた確率 (0.6、0.4) があります。

各ドキュメントは、含まれる用語の tfidf で表され、0 から 1 に正規化されます。また、doc2vec (正規化された形式 -1 から 1) と他のいくつかの方法も試しました。

この確率分布を予測するために、非常に単純なニューラル ネットワークを構築しました。

  • フィーチャと同じ数のノードを持つ入力レイヤー
  • ノードが 1 つの単一の隠れ層
  • ソフトマックスと 2 つのノードを含む出力層
  • 交差エントロピー損失関数
  • 更新機能や学習率も変えてみました

これは、nolearn を使用して記述したコードです。

net = nolearn.lasagne.NeuralNet(
    layers=[('input', layers.InputLayer),
        ('hidden1', layers.DenseLayer),
        ('output', layers.DenseLayer)],
    input_shape=(None, X_train.shape[1]),
    hidden1_num_units=1,
    output_num_units=2,
    output_nonlinearity=lasagne.nonlinearities.softmax,
    objective_loss_function=lasagne.objectives.binary_crossentropy,
    max_epochs=50,
    on_epoch_finished=[es.EarlyStopping(patience=5, gamma=0.0001)],
    regression=True,
    update=lasagne.updates.adam,
    update_learning_rate=0.001,
    verbose=2)
net.fit(X_train, y_train)
y_true, y_pred = y_test, net.predict(X_test)

私の問題は、私の予測にはカットオフポイントがあり、そのポイントを下回る予測はありません (私が何を意味するかを理解するために写真を確認してください)。 このプロットは、真の確率と私の予測の差を示しています。点が赤い線に近いほど、予測は良好です。理想的には、すべての点が線上にあるはずです。どうすればこれを解決できますか?なぜこれが起こっているのですか?

編集:実際には、隠しレイヤーを削除するだけで問題を解決しました:

net = nolearn.lasagne.NeuralNet(
    layers=[('input', layers.InputLayer),
        ('output', layers.DenseLayer)],
    input_shape=(None, X_train.shape[1]),
    output_num_units=2,
    output_nonlinearity=lasagne.nonlinearities.softmax,
    objective_loss_function=lasagne.objectives.binary_crossentropy,
    max_epochs=50,
    on_epoch_finished=[es.EarlyStopping(patience=5, gamma=0.0001)],
    regression=True,
    update=lasagne.updates.adam,
    update_learning_rate=0.001,
    verbose=2)
net.fit(X_train, y_train)
y_true, y_pred = y_test, net.predict(X_test)

しかし、なぜこの問題が発生したのか、非表示レイヤーを削除すると解決したのか、まだ理解できていません。何か案は?

ここに新しいプロットがあります: 2

4

1 に答える 1

0

トレーニング セットの出力値は [0,1] または [1,0] にする必要があると思います。
[0.6,0.4] は softmax/Crossentropy には適していません。

于 2016-08-25T13:57:41.290 に答える