4

scikit-learn 0.12.1 を使用して次のことを試みています。

  1. LogisticRegression 分類子を訓練する
  2. 差し出された検証データで分類子を評価する
  3. この分類子に新しいデータをフィードし、観測ごとに最も可能性の高い 5 つのラベルを取得します

Sklearn は、1 つの特殊性を除いて、これらすべてを非常に簡単にします。分類子に適合させるために使用されるデータで、可能なすべてのラベルが発生するという保証はありません。可能なラベルは何百もありますが、そのうちのいくつかは利用可能なトレーニング データに含まれていません。

これにより、次の 2 つの問題が発生します。

  1. ラベル ベクトライザーは、以前に表示されなかったラベルが検証データに含まれている場合、それを認識しません。これは、可能なラベルのセットにラベラーを適合させることで簡単に修正できますが、問題 2 を悪化させます。
  2. LogisticRegression 分類子の predict_proba メソッドの出力は、[n_samples, n_classes] 配列です。ここで、n_classes は、トレーニング データに見られるクラスのみで構成されます。これは、predict_proba 配列で argsort を実行しても、ラベル ベクトライザーの語彙に直接マップされる値が提供されなくなったことを意味します。

私の質問は、分類器に可能なクラスの完全なセットを強制的に認識させる最良の方法は何ですか?それらの一部がトレーニングデータに含まれていない場合でも? 明らかに、データを見たことがないラベルについて学習するのに問題がありますが、私の状況では 0 は完全に使用可能です。

4

3 に答える 3

3

によって返されるような配列が必要であるが、predict_probasorted に対応する列がある場合は、次のようにしますall_classes

all_classes = numpy.array(sorted(all_classes))
# Get the probabilities for learnt classes
prob = clf.predict_proba(test_samples)
# Create the result matrix, where all values are initially zero
new_prob = numpy.zeros((prob.shape[0], all_classes.size))
# Set the columns corresponding to clf.classes_
new_prob[:, all_classes.searchsorted(clf.classes_)] = prob
于 2013-03-02T13:56:10.883 に答える
2

larsmanの優れた答えに基づいて、私はこれになりました:

from itertools import repeat
import numpy as np

# determine the classes that were not present in the training set;
# the ones that were are listed in clf.classes_.
classes_not_trained = set(clf.classes_).symmetric_difference(all_classes)

# the order of classes in predict_proba's output matches that in clf.classes_.
prob = clf.predict_proba(test_samples)
new_prob = []
for row in prob:
    prob_per_class = zip(clf.classes_, prob) + zip(classes_not_trained, repeat(0.))
    # put the probabilities in class order
    prob_per_class = sorted(prob_per_class)
    new_prob.append(i[1] for i in prob_per_class)
new_prob = np.asarray(new_prob)

new_prob は [n_samples, n_classes] 配列で、predict_proba からの出力と同じように、以前に見えなかったクラスの確率が 0 になることを除いて異なります。

于 2013-02-25T19:24:06.490 に答える