1

各層に 16 ~ 32 個のセルを持つ 2 層の LSTM ネットワークをトレーニングしていますが、トレーニング用のかなり不均衡なデータセットがありました。私の 7 つのクラス頻度に基づいて、total_samples/class_frequency の単純な式で計算されたサンプルの重みは [3.7, 5.6, 26.4, 3.2, 191.6, 8.4, 13.2] であり、各サンプルのこの重みを (data のタプルに追加します。 、ラベル) Kerasmodel.fit()関数を実行するためのデータセット ジェネレーターの出力。トレーニング コードは次のとおりです。

model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
mc = ModelCheckpoint(model_file, monitor='val_acc', mode='max', verbose=1, save_best_only=True)
es = EarlyStopping(monitor='val_acc', mode='max', verbose=1, patience=50)
history = model.fit(train_data, epochs=epochs, steps_per_epoch = train_steps, validation_data=val_data,
                            validation_steps = val_steps, verbose=verbose, callbacks=[es, mc])

次に、保存された最適なモデルを使用して評価し、このコードでパフォーマンス統計を計算しました (私のデータは tensorflow データセットにあります)。

saved_model = load_model(model_file)
iterator = test_data.make_one_shot_iterator()
next_element = iterator.get_next()
y_test = y_pred = np.empty(0)
for i in range(test_steps):
    batch = sess.run(next_element)
    x_test_batch = batch[0]
    y_test_batch = batch[1]
    y_pred_batch = saved_model.predict_on_batch(x_test_batch)
    y_test = np.append(y_test, np.argmax(y_test_batch, axis=1))
    y_pred = np.append(y_pred, np.argmax(y_pred_batch, axis=1))
print('\nTest data classification report:\n{}\n'.format(classification_report(y_test, y_pred)))

しかし、出力統計に見られるのは、まれなクラス (最高の重み) であっても、重み付けされた統計は、重み付けされていないもの (すべての重みを等しく 1 に設定) よりも全体的に悪いということです。統計は次のとおりです。

加重実行の場合:

     class     prec.     recall    f1       support
     0.0       1.00      0.97      0.98     79785
     1.0       0.89      0.88      0.88     52614
     2.0       0.61      0.76      0.68     11090
     3.0       0.96      0.93      0.95     91160
     4.0       0.59      0.92      0.72      1530
     5.0       0.89      0.90      0.89     34746
     6.0       0.81      0.87      0.84     22289

accuracy                           0.92    293214
macro avg      0.82      0.89      0.85    293214

重み付けされていない実行の場合:

     class     prec.     recall    f1       support
     0.0       0.99      0.98      0.99     79785
     1.0       0.89      0.90      0.90     52614
     2.0       0.79      0.66      0.72     11090
     3.0       0.95      0.96      0.95     91160
     4.0       0.85      0.82      0.83      1530
     5.0       0.89      0.92      0.90     34746
     6.0       0.88      0.86      0.87     22289

accuracy                           0.93    293214
macro avg      0.89      0.87      0.88    293214

ここで何が問題なのですか?

4

1 に答える 1