Weka の API には、分類予測分布を取得するために使用できる Classifier.distributionForInstance() というメソッドがあります。次に、確率を減らして分布を並べ替え、上位 N 個の予測を取得できます。
以下は、出力する関数です。(1) テスト インスタンスのグラウンド トゥルース ラベル。(2) classifyInstance() から予測されたラベル。(3) distributionForInstance() からの予測分布。私はこれを J48 で使用しましたが、他の分類器でも動作するはずです。
入力パラメーターは、シリアル化されたモデル ファイル (モデルのトレーニング フェーズ中に -d オプションを適用して作成できます) と、ARFF 形式のテスト ファイルです。
public void test(String modelFileSerialized, String testFileARFF)
throws Exception
{
// Deserialize the classifier.
Classifier classifier =
(Classifier) weka.core.SerializationHelper.read(
modelFileSerialized);
// Load the test instances.
Instances testInstances = DataSource.read(testFileARFF);
// Mark the last attribute in each instance as the true class.
testInstances.setClassIndex(testInstances.numAttributes()-1);
int numTestInstances = testInstances.numInstances();
System.out.printf("There are %d test instances\n", numTestInstances);
// Loop over each test instance.
for (int i = 0; i < numTestInstances; i++)
{
// Get the true class label from the instance's own classIndex.
String trueClassLabel =
testInstances.instance(i).toString(testInstances.classIndex());
// Make the prediction here.
double predictionIndex =
classifier.classifyInstance(testInstances.instance(i));
// Get the predicted class label from the predictionIndex.
String predictedClassLabel =
testInstances.classAttribute().value((int) predictionIndex);
// Get the prediction probability distribution.
double[] predictionDistribution =
classifier.distributionForInstance(testInstances.instance(i));
// Print out the true label, predicted label, and the distribution.
System.out.printf("%5d: true=%-10s, predicted=%-10s, distribution=",
i, trueClassLabel, predictedClassLabel);
// Loop over all the prediction labels in the distribution.
for (int predictionDistributionIndex = 0;
predictionDistributionIndex < predictionDistribution.length;
predictionDistributionIndex++)
{
// Get this distribution index's class label.
String predictionDistributionIndexAsClassLabel =
testInstances.classAttribute().value(
predictionDistributionIndex);
// Get the probability.
double predictionProbability =
predictionDistribution[predictionDistributionIndex];
System.out.printf("[%10s : %6.3f]",
predictionDistributionIndexAsClassLabel,
predictionProbability );
}
o.printf("\n");
}
}