import os
from contextlib import contextmanager

from colbert.utils.utils import print_message, NullContextManager
from colbert.utils.runs import Run


class RankingLogger():
    """Write ranked lists as tab-separated lines to a file.

    Usage is two-step: enter ``context(...)`` to open the output file(s),
    then call ``log(...)`` once per query while the context is active.
    """

    def __init__(self, directory, qrels=None, log_scores=False):
        """Store the logger configuration; no file is opened here.

        Args:
            directory: Directory that output filenames are joined onto.
            qrels: Optional mapping indexed as ``qrels[qid]`` and tested with
                ``pid in qrels[qid]`` to derive the relevance flag.
            log_scores: When true, each ranking line gets the score appended
                as an extra column.
        """
        self.directory = directory
        self.qrels = qrels
        self.filename, self.also_save_annotations = None, None
        self.log_scores = log_scores

        # File handles are populated by context(); initialised here so that
        # log() can detect "called outside a context" explicitly.
        self.f, self.g = None, None

    @contextmanager
    def context(self, filename, also_save_annotations=False):
        """Open the output file(s) and yield this logger for ``log`` calls.

        Args:
            filename: Output file name, joined onto ``self.directory``.
            also_save_annotations: When true, a second file named
                ``<filename>.annotated`` is opened alongside the main one.

        Yields:
            This ``RankingLogger`` instance.

        Note:
            ``self.filename`` is left set after the context exits, so the
            assertions below make each instance usable for one context only.
        """
        assert self.filename is None
        assert self.also_save_annotations is None

        filename = os.path.join(self.directory, filename)
        self.filename, self.also_save_annotations = filename, also_save_annotations

        print_message("#> Logging ranked lists to {}".format(self.filename))

        with open(filename, 'w') as f:
            self.f = f
            # NOTE(review): NullContextManager presumably yields a falsy value
            # (log() relies on `if self.g`) -- confirm against its definition.
            with (open(filename + '.annotated', 'w') if also_save_annotations else NullContextManager()) as g:
                self.g = g
                yield self

    def log(self, qid, ranking, is_ranked=True, print_positions=()):
        """Write one query's ranking to the open output file(s).

        Args:
            qid: Query identifier, written as the first column.
            ranking: Iterable of ``(score, pid, passage)`` triples.
            is_ranked: When true, ranks are 1-based positions in ``ranking``;
                otherwise every line gets rank ``-1``.
            print_positions: Ranks whose entries are also printed to stdout.

        Raises:
            RuntimeError: If called before ``context()`` has been entered.
        """
        if self.f is None:
            raise RuntimeError("RankingLogger.log() must be called inside RankingLogger.context()")

        print_positions = set(print_positions)

        f_buffer = []
        g_buffer = []

        for rank, (score, pid, passage) in enumerate(ranking):
            # When self.qrels is falsy (e.g. None) this evaluates to self.qrels
            # itself, not 0/1, and that value is what gets written/printed.
            is_relevant = self.qrels and int(pid in self.qrels[qid])
            rank = rank + 1 if is_ranked else -1
            possibly_score = [score] if self.log_scores else []

            f_buffer.append('\t'.join([str(x) for x in [qid, pid, rank] + possibly_score]) + "\n")
            if self.g:
                g_buffer.append('\t'.join([str(x) for x in [qid, pid, rank, is_relevant]]) + "\n")

            if rank in print_positions:
                prefix = "** " if is_relevant else ""
                prefix += str(rank)
                print("#> ( QID {} ) ".format(qid) + prefix + ") ", pid, ":", score, ' ', passage)

        # One write per file per query instead of one per line.
        self.f.write(''.join(f_buffer))
        if self.g:
            self.g.write(''.join(g_buffer))