6

私は配列プログラミングが初めてで、sklearn.metrics label_ranking_average_precision_score 関数を解釈するのが難しいと感じました。それが計算される方法を理解するためにあなたの助けが必要です.Numpy Array Programmingを学ぶためのヒントがあれば感謝します.


一般的に、精度は((True Positive) / (True Positive + False Positive)) であることを知っています。

私が質問している理由は、オーディオ タグ付けの Kaggle コンペティションに出くわし、応答に複数の正しいラベルがある場合、スコアを計算するために LWRAP 関数を使用しているというこの投稿に出くわしたためです。このスコアの計算方法を知りたくて読み始めたのですが、解釈が難しいことに気づきました。私の2つの困難は、
1)ドキュメントからMath関数を解釈することです。スコア計算でランクがどのように使用されるかわかりません
2)コードからNumpy配列操作を解釈
する私が読んでいる関数はGoogle Collabドキュメントからのものであり、ドキュメントを読んでみましたsklearnでしたが、正しく理解できませんでした。

1 つのサンプル計算のコードは次のとおりです。

# Core calculation of label precisions for one test sample.

def _one_sample_positive_class_precisions(scores, truth):
  """Calculate precisions for each true class for a single sample.

  Args:
    scores: np.array of (num_classes,) giving the individual classifier scores.
    truth: np.array of (num_classes,) bools indicating which classes are true.

  Returns:
    pos_class_indices: np.array of indices of the true classes for this sample.
    pos_class_precisions: np.array of precisions corresponding to each of those
      classes.
  """
  num_classes = scores.shape[0]
  pos_class_indices = np.flatnonzero(truth > 0)
  # Only calculate precisions if there are some true classes.
  if not len(pos_class_indices):
    return pos_class_indices, np.zeros(0)
  # Retrieval list of classes for this sample. 
  retrieved_classes = np.argsort(scores)[::-1]
  # class_rankings[top_scoring_class_index] == 0 etc.
  class_rankings = np.zeros(num_classes, dtype=np.int)
  class_rankings[retrieved_classes] = range(num_classes)
  # Which of these is a true label?
  retrieved_class_true = np.zeros(num_classes, dtype=np.bool)
  retrieved_class_true[class_rankings[pos_class_indices]] = True
  # Num hits for every truncated retrieval list.
  retrieved_cumulative_hits = np.cumsum(retrieved_class_true)
  # Precision of retrieval list truncated at each hit, in order of pos_labels.
  precision_at_hits = (
      retrieved_cumulative_hits[class_rankings[pos_class_indices]] / 
      (1 + class_rankings[pos_class_indices].astype(np.float)))
  return pos_class_indices, precision_at_hits
4

1 に答える 1