import os
import sys
import git
import tqdm
import ujson
import random

from argparse import ArgumentParser
from multiprocessing import Pool

from colbert.utils.utils import groupby_first_item, print_message
from utility.utils.qa_loaders import load_qas_, load_collection_
from utility.utils.save_metadata import format_metadata, get_metadata
from utility.evaluate.annotate_EM_helpers import *

from colbert.infra.run import Run
from colbert.data.collection import Collection
from colbert.data.ranking import Ranking
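
# AnnotateEM joins a retrieved ranking (qid, pid, rank, ...) with the passage text and the
# gold answers for each query, assigns an exact-match (EM) label to every passage in
# parallel, and reports Success@k and match counts at several ranking depths.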
class AnnotateEM:
    def __init__(self, collection, qas):
        # TODO: These should just be Queries! But Queries needs to support looking up answers as qid2answers below.
        qas = load_qas_(qas)
        collection = Collection.cast(collection)  # .tolist()  # load_collection_(collection, retain_titles=True)

        self.parallel_pool = Pool(30)

        print_message('#> Tokenize the answers in the Q&As in parallel...')
        qas = list(self.parallel_pool.map(tokenize_all_answers, qas))

        qid2answers = {qid: tok_answers for qid, _, tok_answers in qas}
        assert len(qas) == len(qid2answers), (len(qas), len(qid2answers))

        self.qas, self.collection = qas, collection
        self.qid2answers = qid2answers
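
    # annotate(): attach passage text and gold answers to each (qid, pid, rank) row,
    # assign EM labels in parallel, and compute Success@k / match counts over the results.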
    def annotate(self, ranking):
        rankings = Ranking.cast(ranking)

        # print(len(rankings), rankings[0])

        print_message('#> Lookup passages from PIDs...')
        expanded_rankings = [(qid, pid, rank, self.collection[pid], self.qid2answers[qid])
                             for qid, pid, rank, *_ in rankings.tolist()]

        print_message('#> Assign labels in parallel...')
        labeled_rankings = list(self.parallel_pool.map(assign_label_to_passage, enumerate(expanded_rankings)))

        # Group the labeled rows by query ID.
        self.qid2rankings = groupby_first_item(labeled_rankings)

        self.num_judged_queries, self.num_ranked_queries = check_sizes(self.qid2answers, self.qid2rankings)

        # Evaluation metrics and depths.
        self.success, self.counts = self._compute_labels(self.qid2answers, self.qid2rankings)

        print(rankings.provenance(), self.success)

        return Ranking(data=self.qid2rankings, provenance=("AnnotateEM", rankings.provenance()))
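
    # _compute_labels(): for each judged query, check that ranks are consecutive and
    # accumulate, per cutoff k, whether any top-k passage matches (success) and how
    # many top-k passages match (counts).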
    def _compute_labels(self, qid2answers, qid2rankings):
        cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all']
        success = {cutoff: 0.0 for cutoff in cutoffs}
        counts = {cutoff: 0.0 for cutoff in cutoffs}

        for qid in qid2answers:
            if qid not in qid2rankings:
                continue

            prev_rank = 0  # ranks should start at one (i.e., not zero)
            labels = []

            for pid, rank, label in qid2rankings[qid]:
                assert rank == prev_rank + 1, (qid, pid, (prev_rank, rank))
                prev_rank = rank

                labels.append(label)

            for cutoff in cutoffs:
                if cutoff != 'all':
                    success[cutoff] += sum(labels[:cutoff]) > 0
                    counts[cutoff] += sum(labels[:cutoff])
                else:
                    success[cutoff] += sum(labels) > 0
                    counts[cutoff] += sum(labels)

        return success, counts
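
    # save(): write the labeled ranking to new_path and a companion `{new_path}.metrics`
    # file with Success@k and counts normalized by the number of judged queries.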
    def save(self, new_path):
        print_message("#> Dumping output to", new_path, "...")

        Ranking(data=self.qid2rankings).save(new_path)

        # Dump metrics.
        with Run().open(f'{new_path}.metrics', 'w') as f:
            d = {'num_ranked_queries': self.num_ranked_queries, 'num_judged_queries': self.num_judged_queries}

            extra = '__WARNING' if self.num_judged_queries != self.num_ranked_queries else ''
            d[f'success{extra}'] = {k: v / self.num_judged_queries for k, v in self.success.items()}
            d[f'counts{extra}'] = {k: v / self.num_judged_queries for k, v in self.counts.items()}
            # d['arguments'] = get_metadata(args)  # TODO: Need arguments...

            f.write(format_metadata(d) + '\n')
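
# Small manual smoke test: annotate a ranking TSV passed on the command line against the
# NQ-mini collection and QAs. The two hard-coded paths below are leftover examples and are
# immediately overridden by sys.argv[1].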
if __name__ == '__main__':
    r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.50.02/ranking.tsv'
    r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.59.37/ranking.tsv'
    r = sys.argv[1]

    a = AnnotateEM(collection='/future/u/okhattab/root/unit/data/NQ-mini/collection.tsv',
                   qas='/future/u/okhattab/root/unit/data/NQ-mini/dev/qas.json')
    a.annotate(ranking=r)
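    # To also persist the labeled ranking and its metrics, call save(); the output
    # filename below is illustrative:
    #   a.save('annotated.ranking.tsv')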