Spaces:
Runtime error
Runtime error
import os | |
import ujson | |
import random | |
from argparse import ArgumentParser | |
from colbert.utils.utils import print_message, create_directory, load_ranking, groupby_first_item | |
from utility.utils.qa_loaders import load_qas_ | |
def main(args):
    """Write a random subsample of the dev-set rankings to ``args.output``.

    Loads the QAs from ``args.qas`` and the ranking from ``args.ranking``,
    groups the ranking rows by their first field (the qid), draws
    ``args.sample`` QAs at random, and writes every ranking row of each
    sampled query as one tab-separated line prefixed with that qid.

    If ``args.sample`` is larger than the number of loaded QAs, every QA is
    kept (in random order) and a warning is printed, instead of crashing
    with ``ValueError`` from ``random.sample``.

    Args:
        args: parsed CLI namespace with ``qas``, ``ranking``, ``output``
            and ``sample`` attributes.
    """
    print_message("#> Loading all..")

    qas = load_qas_(args.qas)
    rankings = load_ranking(args.ranking)
    qid2rankings = groupby_first_item(rankings)

    print_message("#> Subsampling all..")

    # random.sample raises ValueError when k > len(population); clamp so a
    # dev set smaller than --sample (default 1500) still works.
    sample_size = min(args.sample, len(qas))
    if sample_size < args.sample:
        print_message("#> Requested", args.sample, "samples but only",
                      len(qas), "QAs are available; keeping all of them.")

    qas_sample = random.sample(qas, sample_size)

    with open(args.output, 'w') as f:
        for qid, *_ in qas_sample:
            for items in qid2rankings[qid]:
                # Re-attach the qid that grouping stripped from each row.
                row = [qid] + items
                f.write('\t'.join(map(str, row)) + '\n')

    print('\n\n')
    print(args.output)
    print("#> Done.")
if __name__ == "__main__":
    # Fixed seed so the subsample is reproducible across runs.
    random.seed(12345)

    parser = ArgumentParser(description='Subsample the dev set.')
    parser.add_argument('--qas', dest='qas', required=True, type=str)
    parser.add_argument('--ranking', dest='ranking', required=True)
    parser.add_argument('--output', dest='output', required=True)
    parser.add_argument('--sample', dest='sample', default=1500, type=int)

    args = parser.parse_args()

    # random.sample rejects a negative k; report it as a usage error instead
    # of a traceback.
    if args.sample < 0:
        parser.error("--sample must be non-negative, got {}".format(args.sample))

    # Refuse to overwrite an existing file. An explicit check (rather than
    # `assert`) keeps this guard active under `python -O`.
    if os.path.exists(args.output):
        parser.error("output path already exists: {}".format(args.output))

    # os.path.dirname returns '' for a bare filename (i.e. the current
    # directory, which already exists); only create a directory when the
    # output path actually names one.
    output_dir = os.path.dirname(args.output)
    if output_dir:
        create_directory(output_dir)

    main(args)