import ujson

from collections import defaultdict
from colbert.utils.runs import Run


class Metrics:
    """Accumulates MRR@k, Success@k, and Recall@k over a stream of ranked queries."""

    def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None):
        self.results = {}
        self.mrr_sums = {depth: 0.0 for depth in mrr_depths}
        self.recall_sums = {depth: 0.0 for depth in recall_depths}
        self.success_sums = {depth: 0.0 for depth in success_depths}
        self.total_queries = total_queries

        self.max_query_idx = -1
        self.num_queries_added = 0
    def add(self, query_idx, query_key, ranking, gold_positives):
        self.num_queries_added += 1

        # Queries must arrive in order, each with unique gold labels and a
        # ranking whose passage IDs (second element of each triple) are unique.
        assert query_key not in self.results
        assert len(self.results) <= query_idx
        assert len(set(gold_positives)) == len(gold_positives)
        assert len(set([pid for _, pid, _ in ranking])) == len(ranking)

        self.results[query_key] = ranking

        # Zero-based positions at which gold positives appear in the ranking.
        positives = [i for i, (_, pid, _) in enumerate(ranking) if pid in gold_positives]

        if len(positives) == 0:
            return

        for depth in self.mrr_sums:
            first_positive = positives[0]
            self.mrr_sums[depth] += (1.0 / (first_positive + 1.0)) if first_positive < depth else 0.0

        for depth in self.success_sums:
            first_positive = positives[0]
            self.success_sums[depth] += 1.0 if first_positive < depth else 0.0

        for depth in self.recall_sums:
            num_positives_up_to_depth = len([pos for pos in positives if pos < depth])
            self.recall_sums[depth] += num_positives_up_to_depth / len(gold_positives)
    def print_metrics(self, query_idx):
        for depth in sorted(self.mrr_sums):
            print("MRR@" + str(depth), "=", self.mrr_sums[depth] / (query_idx + 1.0))

        for depth in sorted(self.success_sums):
            print("Success@" + str(depth), "=", self.success_sums[depth] / (query_idx + 1.0))

        for depth in sorted(self.recall_sums):
            print("Recall@" + str(depth), "=", self.recall_sums[depth] / (query_idx + 1.0))
    def log(self, query_idx):
        assert query_idx >= self.max_query_idx
        self.max_query_idx = query_idx

        # Log running averages (over query_idx + 1 queries so far) to the experiment tracker.
        Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
        Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx)

        for depth in sorted(self.mrr_sums):
            score = self.mrr_sums[depth] / (query_idx + 1.0)
            Run.log_metric("ranking/MRR." + str(depth), score, query_idx)

        for depth in sorted(self.success_sums):
            score = self.success_sums[depth] / (query_idx + 1.0)
            Run.log_metric("ranking/Success." + str(depth), score, query_idx)

        for depth in sorted(self.recall_sums):
            score = self.recall_sums[depth] / (query_idx + 1.0)
            Run.log_metric("ranking/Recall." + str(depth), score, query_idx)
    def output_final_metrics(self, path, query_idx, num_queries):
        assert query_idx + 1 == num_queries
        assert num_queries == self.total_queries

        if self.max_query_idx < query_idx:
            self.log(query_idx)

        self.print_metrics(query_idx)

        output = defaultdict(dict)

        for depth in sorted(self.mrr_sums):
            score = self.mrr_sums[depth] / (query_idx + 1.0)
            output['mrr'][depth] = score

        for depth in sorted(self.success_sums):
            score = self.success_sums[depth] / (query_idx + 1.0)
            output['success'][depth] = score

        for depth in sorted(self.recall_sums):
            score = self.recall_sums[depth] / (query_idx + 1.0)
            output['recall'][depth] = score

        # Dump the final per-depth averages as indented JSON.
        with open(path, 'w') as f:
            ujson.dump(output, f, indent=4)
            f.write('\n')
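
# Usage sketch (illustrative only; the query key, pids, and scores below are
# hypothetical, not part of this module). Each ranking entry is a triple whose
# second element is the passage id; the other two fields are unused by Metrics:
#
#     metrics = Metrics(mrr_depths={10}, recall_depths={50}, success_depths={10})
#     metrics.add(0, 'q0', [(0.9, 7, None), (0.8, 3, None)], gold_positives=[7])
#     metrics.print_metrics(query_idx=0)  # MRR@10 = 1.0: the gold pid ranks first
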
def evaluate_recall(qrels, queries, topK_pids):
    if qrels is None:
        return

    assert set(qrels.keys()) == set(queries.keys())

    # Macro-averaged recall at the full retrieval depth: per query, the fraction
    # of relevant passages that appear anywhere in its top-K list.
    recall_at_k = [len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) / max(1.0, len(qrels[qid]))
                   for qid in qrels]
    recall_at_k = sum(recall_at_k) / len(qrels)
    recall_at_k = round(recall_at_k, 3)
    print("Recall @ maximum depth =", recall_at_k)


# TODO: If implicit qrels are used (for re-ranking), warn if a recall metric is requested + add an asterisk to output.
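
if __name__ == '__main__':
    # Minimal smoke test on hypothetical toy data (not part of the original
    # module). Running it requires the colbert package on the path, because of
    # the Run import above.
    qrels = {'q0': [7], 'q1': [2]}
    queries = {'q0': 'first query', 'q1': 'second query'}
    topK_pids = {'q0': [7, 3], 'q1': [5, 2]}
    evaluate_recall(qrels, queries, topK_pids)  # prints: Recall @ maximum depth = 1.0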