37

Keras で加重バイナリ クロスエントロピーを実装しようとしましたが、コードが正しいかどうかわかりません。トレーニングの出力は少し混乱しているようです。数エポックの後、精度が ~0.15 になりました。それは少なすぎると思います(ランダムな推測であっても)。

一般に、出力には約 11% の 1 と 89% のゼロがあるため、重みは w_zero=0.89 および w_one=0.11 です。

私のコード:

def create_weighted_binary_crossentropy(zero_weight, one_weight):

    def weighted_binary_crossentropy(y_true, y_pred):

        # Original binary crossentropy (see losses.py):
        # K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)

        # Calculate the binary crossentropy
        b_ce = K.binary_crossentropy(y_true, y_pred)

        # Apply the weights
        weight_vector = y_true * one_weight + (1. - y_true) * zero_weight
        weighted_b_ce = weight_vector * b_ce

        # Return the mean error
        return K.mean(weighted_b_ce)

    return weighted_binary_crossentropy

多分誰かが何が間違っているのを見ますか?

ありがとうございました

4

6 に答える 6

23

sklearn モジュールを使用して、次のように各クラスの重みを自動的に計算できます。

# Import
import numpy as np
from sklearn.utils import class_weight

# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))

# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
                                            np.unique(y_train),
                                            y_train)

# Add the class weights to the training                                         
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

class_weight.compute_class_weight()の出力は、次のような numpy 配列であることに注意してください[2.57569845 0.68250928]

于 2019-03-15T10:34:03.343 に答える
2

model.fit でクラスの重みを使用するのは正しくないと思います。{0:0.11, 1:0.89}、ここの 0 はインデックスであり、0 クラスではありません。Kerasドキュメント: https://keras.io/models/sequential/ class_weight: クラス インデックス (整数) を重み (float) 値にマッピングするオプションのディクショナリ。損失関数の重み付けに使用されます (トレーニング中のみ)。これは、少数のクラスからのサンプルに「もっと注意を払う」ようにモデルに指示するのに役立ちます。

于 2017-11-15T01:09:23.897 に答える