私は配列プログラミングが初めてで、sklearn.metrics label_ranking_average_precision_score 関数を解釈するのが難しいと感じました。それが計算される方法を理解するためにあなたの助けが必要です.Numpy Array Programmingを学ぶためのヒントがあれば感謝します.
一般的に、精度は((True Positive) / (True Positive + False Positive)) であることを知っています。
私が質問している理由は、オーディオ タグ付けの Kaggle コンペティションに出くわし、応答に複数の正しいラベルがある場合、スコアを計算するために LWRAP 関数を使用しているというこの投稿に出くわしたためです。このスコアの計算方法を知りたくて読み始めたのですが、解釈が難しいことに気づきました。私の2つの困難は、
1)ドキュメントからMath関数を解釈することです。スコア計算でランクがどのように使用されるかわかりません
2)コードからNumpy配列操作を解釈
する私が読んでいる関数はGoogle Collabドキュメントからのものであり、ドキュメントを読んでみましたsklearnでしたが、正しく理解できませんでした。
1 つのサンプル計算のコードは次のとおりです。
# Core calculation of label precisions for one test sample.
def _one_sample_positive_class_precisions(scores, truth):
"""Calculate precisions for each true class for a single sample.
Args:
scores: np.array of (num_classes,) giving the individual classifier scores.
truth: np.array of (num_classes,) bools indicating which classes are true.
Returns:
pos_class_indices: np.array of indices of the true classes for this sample.
pos_class_precisions: np.array of precisions corresponding to each of those
classes.
"""
num_classes = scores.shape[0]
pos_class_indices = np.flatnonzero(truth > 0)
# Only calculate precisions if there are some true classes.
if not len(pos_class_indices):
return pos_class_indices, np.zeros(0)
# Retrieval list of classes for this sample.
retrieved_classes = np.argsort(scores)[::-1]
# class_rankings[top_scoring_class_index] == 0 etc.
class_rankings = np.zeros(num_classes, dtype=np.int)
class_rankings[retrieved_classes] = range(num_classes)
# Which of these is a true label?
retrieved_class_true = np.zeros(num_classes, dtype=np.bool)
retrieved_class_true[class_rankings[pos_class_indices]] = True
# Num hits for every truncated retrieval list.
retrieved_cumulative_hits = np.cumsum(retrieved_class_true)
# Precision of retrieval list truncated at each hit, in order of pos_labels.
precision_at_hits = (
retrieved_cumulative_hits[class_rankings[pos_class_indices]] /
(1 + class_rankings[pos_class_indices].astype(np.float)))
return pos_class_indices, precision_at_hits