Spaces:
Runtime error
Runtime error
File size: 2,495 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
from colbert.utils.utils import print_message
from utility.utils.dpr import DPR_normalize, has_answer
def tokenize_all_answers(args):
    """Normalize every answer string belonging to one query.

    ``args`` is a single ``(qid, question, answers)`` triple. It is packed into
    one argument, presumably so the function can be used with a ``map``-style
    API (TODO confirm against callers).

    Returns:
        ``(qid, question, normalized_answers)``, where ``normalized_answers`` is a
        list holding ``DPR_normalize(answer)`` for each answer, in order.
    """
    query_id, query_text, raw_answers = args

    normalized = []
    for answer in raw_answers:
        normalized.append(DPR_normalize(answer))

    return query_id, query_text, normalized
def assign_label_to_passage(args):
    """Compute the answer-match label for a single ranked passage.

    ``args`` is ``(idx, (qid, pid, rank, passage, tokenized_answers))``. ``idx``
    is only used to print a progress counter once every 1,000,000 items.

    Returns:
        ``(qid, pid, rank, has_answer(tokenized_answers, passage))``.
    """
    position, payload = args
    qid, pid, rank, passage, tokenized_answers = payload

    # Progress heartbeat: print the running index every million items.
    if not position % (1000 * 1000):
        print(position)

    label = has_answer(tokenized_answers, passage)
    return qid, pid, rank, label
def check_sizes(qid2answers, qid2rankings):
    """Log and sanity-check the number of judged vs. ranked queries.

    Both counts are reported via ``print_message``. When they differ, the
    function asserts that the rankings cover no more queries than the answers
    do, then prints a warning banner.

    Returns:
        ``(num_judged_queries, num_ranked_queries)``, i.e.
        ``(len(qid2answers), len(qid2rankings))``.

    Raises:
        AssertionError: if ``qid2rankings`` has more entries than ``qid2answers``.
    """
    num_judged_queries, num_ranked_queries = len(qid2answers), len(qid2rankings)

    for caption, value in (('num_judged_queries =', num_judged_queries),
                           ('num_ranked_queries =', num_ranked_queries)):
        print_message(caption, value)

    if num_judged_queries == num_ranked_queries:
        return num_judged_queries, num_ranked_queries

    # Rankings may cover a subset of the judged queries, but never more.
    assert num_ranked_queries <= num_judged_queries

    print('\n\n')
    print_message('[WARNING] num_judged_queries != num_ranked_queries')
    print('\n\n')

    return num_judged_queries, num_ranked_queries
def compute_and_write_labels(output_path, qid2answers, qid2rankings, *, cutoffs=None):
    """Write per-passage labels to a TSV file and aggregate success/count stats.

    For every qid in ``qid2answers`` that also appears in ``qid2rankings``, one
    line ``qid<TAB>pid<TAB>rank<TAB>int(label)`` is written per ranked passage.
    Queries absent from ``qid2rankings`` are skipped.

    Args:
        output_path: File to write; it is opened in ``'w'`` mode (truncated).
        qid2answers: Mapping whose keys are the qids to process. Only its keys
            are read here.
        qid2rankings: Mapping from qid to an iterable of ``(pid, rank, label)``
            triples. For each qid the ranks must be 1, 2, 3, ... in order.
        cutoffs: Optional iterable of integer cutoffs, which may include the
            string ``'all'`` meaning "no truncation". Defaults to
            ``[1, 5, 10, 20, 30, 50, 100, 1000, 'all']``.

    Returns:
        ``(success, counts)``: two dicts keyed by cutoff, with float values.
        ``success[c]`` is the number of processed queries whose top ``c`` labels
        sum to more than zero. ``counts[c]`` is the sum of the top ``c`` labels,
        accumulated over processed queries.

    Raises:
        AssertionError: if a query's ranks are not consecutive starting at 1.
    """
    if cutoffs is None:
        cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all']
    else:
        # Materialise once: the cutoffs are iterated several times below.
        cutoffs = list(cutoffs)

    success = {cutoff: 0.0 for cutoff in cutoffs}
    counts = {cutoff: 0.0 for cutoff in cutoffs}

    with open(output_path, 'w') as f:
        for qid in qid2answers:
            if qid not in qid2rankings:
                continue

            # Ranks must start at one (not zero) and increase by exactly one.
            prev_rank = 0
            labels = []

            for pid, rank, label in qid2rankings[qid]:
                assert rank == prev_rank + 1, (qid, pid, (prev_rank, rank))
                prev_rank = rank

                labels.append(label)
                f.write('\t'.join(map(str, [qid, pid, rank, int(label)])) + '\n')

            for cutoff in cutoffs:
                # 'all' means no truncation; otherwise keep the top-`cutoff` labels.
                top_labels = labels if cutoff == 'all' else labels[:cutoff]
                num_positive = sum(top_labels)  # computed once, used for both stats
                success[cutoff] += num_positive > 0
                counts[cutoff] += num_positive

    return success, counts
# def dump_metrics(f, nqueries, cutoffs, success, counts):
# for cutoff in cutoffs:
# success_log = "#> P@{} = {}".format(cutoff, success[cutoff] / nqueries)
# counts_log = "#> D@{} = {}".format(cutoff, counts[cutoff] / nqueries)
# print('\n'.join([success_log, counts_log]) + '\n')
# f.write('\n'.join([success_log, counts_log]) + '\n\n')
|