File size: 4,304 Bytes
58627fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import ujson

from collections import defaultdict
from colbert.utils.runs import Run


class Metrics:
    """Accumulate MRR, Success and Recall at several cutoff depths over a stream of queries.

    Per-depth running sums are kept in ``mrr_sums``, ``success_sums`` and
    ``recall_sums``. Reported averages divide those sums by ``query_idx + 1``
    (the index passed to the reporting method), not by ``num_queries_added``.
    """

    def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None):
        """Create zeroed accumulators for every requested depth.

        Args:
            mrr_depths: Cutoff depths at which MRR is accumulated.
            recall_depths: Cutoff depths at which Recall is accumulated.
            success_depths: Cutoff depths at which Success is accumulated.
            total_queries: Expected number of queries; only checked by
                ``output_final_metrics``.
        """
        self.results = {}  # query_key -> ranking, exactly as passed to add()
        self.mrr_sums = {depth: 0.0 for depth in mrr_depths}
        self.recall_sums = {depth: 0.0 for depth in recall_depths}
        self.success_sums = {depth: 0.0 for depth in success_depths}
        self.total_queries = total_queries

        self.max_query_idx = -1  # largest query_idx handed to log() so far
        self.num_queries_added = 0

    @staticmethod
    def _averages(sums, query_idx):
        """Return ``(depth, sums[depth] / (query_idx + 1))`` pairs in ascending depth order."""
        denominator = query_idx + 1.0
        return [(depth, sums[depth] / denominator) for depth in sorted(sums)]

    def add(self, query_idx, query_key, ranking, gold_positives):
        """Record one query's ranking and update every running sum.

        Args:
            query_idx: Zero-based position of the query; must be at least the
                number of queries already stored.
            query_key: Key under which the ranking is stored; must be new.
            ranking: Sequence of 3-tuples whose middle element is the pid;
                pids must be unique within the ranking.
            gold_positives: Relevant pids for the query, without duplicates.

        Raises:
            AssertionError: If any of the uniqueness/ordering checks fails.
        """
        self.num_queries_added += 1

        assert query_key not in self.results
        assert len(self.results) <= query_idx

        # Built once: serves the duplicate check and gives O(1) membership
        # tests while scanning the ranking.
        gold_set = set(gold_positives)

        assert len(gold_set) == len(gold_positives)
        assert len(set([pid for _, pid, _ in ranking])) == len(ranking)

        self.results[query_key] = ranking

        # Zero-based ranks (ascending) at which a gold pid was retrieved.
        positives = [rank for rank, (_, pid, _) in enumerate(ranking) if pid in gold_set]

        if not positives:
            return

        first_positive = positives[0]

        for depth in self.mrr_sums:
            if first_positive < depth:
                self.mrr_sums[depth] += 1.0 / (first_positive + 1.0)

        for depth in self.success_sums:
            if first_positive < depth:
                self.success_sums[depth] += 1.0

        for depth in self.recall_sums:
            num_positives_up_to_depth = sum(1 for pos in positives if pos < depth)
            self.recall_sums[depth] += num_positives_up_to_depth / len(gold_positives)

    def print_metrics(self, query_idx):
        """Print every metric, averaged over ``query_idx + 1`` queries, to stdout."""
        for depth, score in self._averages(self.mrr_sums, query_idx):
            print("MRR@" + str(depth), "=", score)

        for depth, score in self._averages(self.success_sums, query_idx):
            print("Success@" + str(depth), "=", score)

        for depth, score in self._averages(self.recall_sums, query_idx):
            print("Recall@" + str(depth), "=", score)

    def log(self, query_idx):
        """Send the current averages to ``Run.log_metric`` with ``query_idx`` as the step.

        Raises:
            AssertionError: If ``query_idx`` is smaller than one logged earlier.
        """
        assert query_idx >= self.max_query_idx
        self.max_query_idx = query_idx

        Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
        Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx)

        for depth, score in self._averages(self.mrr_sums, query_idx):
            Run.log_metric("ranking/MRR." + str(depth), score, query_idx)

        for depth, score in self._averages(self.success_sums, query_idx):
            Run.log_metric("ranking/Success." + str(depth), score, query_idx)

        for depth, score in self._averages(self.recall_sums, query_idx):
            Run.log_metric("ranking/Recall." + str(depth), score, query_idx)

    def output_final_metrics(self, path, query_idx, num_queries):
        """Log (if not already logged), print, and write the final averages as JSON to ``path``.

        Args:
            path: Destination file; it is opened for writing, replacing existing content.
            query_idx: Index of the last query; must equal ``num_queries - 1``.
            num_queries: Number of queries evaluated; must equal ``total_queries``.

        Raises:
            AssertionError: If the counts are inconsistent.
        """
        assert query_idx + 1 == num_queries
        assert num_queries == self.total_queries

        if self.max_query_idx < query_idx:
            self.log(query_idx)

        self.print_metrics(query_idx)

        # defaultdict so that a metric with no requested depths is simply absent.
        output = defaultdict(dict)

        for depth, score in self._averages(self.mrr_sums, query_idx):
            output['mrr'][depth] = score

        for depth, score in self._averages(self.success_sums, query_idx):
            output['success'][depth] = score

        for depth, score in self._averages(self.recall_sums, query_idx):
            output['recall'][depth] = score

        with open(path, 'w') as f:
            ujson.dump(output, f, indent=4)
            f.write('\n')


def evaluate_recall(qrels, queries, topK_pids):
    """Print the mean recall of the retrieved pids over all judged queries.

    For each query id in ``qrels``, recall is the share of its judged pids
    that also occur in ``topK_pids[qid]``; a query with no judged pids
    contributes 0. The mean over all queries is rounded to three decimals
    and printed. Nothing is returned.

    Does nothing when ``qrels`` is None. Asserts that ``qrels`` and
    ``queries`` have identical key sets.
    """
    if qrels is None:
        return

    assert set(qrels.keys()) == set(queries.keys())

    per_query_recall = []
    for qid in qrels:
        judged = qrels[qid]
        retrieved_hits = set(judged) & set(topK_pids[qid])
        # max(1.0, ...) keeps the division defined when a query has no judged pids.
        per_query_recall.append(len(retrieved_hits) / max(1.0, len(judged)))

    mean_recall = round(sum(per_query_recall) / len(qrels), 3)
    print("Recall @ maximum depth =", mean_recall)


# TODO: If implicit qrels are used (for re-ranking), warn if a recall metric is requested + add an asterisk to output.