Spaces:
Runtime error
Runtime error
import os | |
import sys | |
import tqdm | |
import ujson | |
import random | |
from argparse import ArgumentParser | |
from collections import OrderedDict | |
from colbert.utils.utils import print_message, file_tqdm | |
def main(args):
    """Split a ranking file into one output file per query file.

    Each file in ``args.all_queries`` is read as tab-separated text whose
    first column is an integer query id. Every line of ``args.ranking``
    (also tab-separated, query id in the first column) is then copied
    verbatim to ``f"{args.ranking}.{idx}"``, where ``idx`` is the position
    in ``args.all_queries`` of the query file that lists that query id.

    Args:
        args: Namespace with ``ranking`` (path to the ranking file) and
            ``all_queries`` (list of paths to the query files).

    Raises:
        AssertionError: If a query id appears in more than one query file,
            or if any of the output paths already exists.
        KeyError: If the ranking file mentions a query id that is not in
            any of the query files.
        ValueError: If a first column cannot be parsed as an integer.
    """
    # Imported locally so this function does not depend on a file-level import.
    from contextlib import ExitStack

    # Maps each query id to the index of the query file that contains it.
    qid_to_file_idx = {}

    for qrels_idx, qrels in enumerate(args.all_queries):
        with open(qrels) as f:
            for line in f:
                qid, *_ = line.strip().split('\t')
                qid = int(qid)

                # Raised explicitly instead of via `assert` so that the check
                # is not stripped when Python runs with -O.
                if qid_to_file_idx.setdefault(qid, qrels_idx) != qrels_idx:
                    raise AssertionError((qid, qrels_idx))

    all_outputs_paths = [f'{args.ranking}.{idx}' for idx in range(len(args.all_queries))]

    # Never overwrite earlier results; this guard must also survive -O.
    existing_paths = [path for path in all_outputs_paths if os.path.exists(path)]
    if existing_paths:
        raise AssertionError(f"output file(s) already exist: {existing_paths}")

    # ExitStack closes every output file even if the loop below raises
    # (unknown qid, malformed line, missing ranking file, ...).
    with ExitStack() as stack:
        all_outputs = [stack.enter_context(open(path, 'w')) for path in all_outputs_paths]

        with open(args.ranking) as f:
            print_message(f"#> Loading ranked lists from {f.name} ..")

            last_file_idx = -1

            # file_tqdm presumably wraps iteration over the file with a
            # progress bar -- confirm in colbert.utils.utils.
            for line in file_tqdm(f):
                qid, *_ = line.strip().split('\t')
                file_idx = qid_to_file_idx[int(qid)]

                # Log only when consecutive lines go to different outputs.
                if file_idx != last_file_idx:
                    print_message(f"#> Switched to file #{file_idx} at {all_outputs[file_idx].name}")

                last_file_idx = file_idx

                all_outputs[file_idx].write(line)

        print()

        for f in all_outputs:
            print(f.name)

    print("#> Done!")
if __name__ == "__main__":
    # Seed the `random` module with a fixed value before doing anything else.
    random.seed(12345)

    cli = ArgumentParser(description='.')

    # Input Arguments
    cli.add_argument(
        '--ranking',
        dest='ranking',
        required=True,
        type=str,
    )
    # One or more query files; their order determines the output suffixes.
    cli.add_argument(
        '--all-queries',
        dest='all_queries',
        required=True,
        type=str,
        nargs='+',
    )

    args = cli.parse_args()
    main(args)