File size: 1,334 Bytes
58627fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
Split the ranked lists after retrieval with a merged query set.
"""

import os
import random

from argparse import ArgumentParser


def main(args):
    """Split a merged ranking file into one output file per query set.

    Every non-blank line of ``args.ranking`` is tab-separated: ``qid``,
    ``pid``, ``rank`` and then any number of extra columns. The line is
    routed to the split at index ``qid // args.gap - 1`` of ``args.names``
    and written with ``qid`` replaced by ``qid % args.gap``; the remaining
    columns are copied through unchanged. The output for split ``name`` is
    written to ``<args.ranking>.<name>``.

    Args:
        args: Namespace with ``ranking`` (path of the merged ranking file),
            ``names`` (split names, in offset order) and ``gap`` (integer
            offset separating the query sets).

    Raises:
        AssertionError: If any of the output paths already exists.
        ValueError: If a line has fewer than three columns or a
            non-integer qid.
        IndexError: If a qid maps past the last entry of ``args.names``.
    """
    output_paths = ['{}.{}'.format(args.ranking, split) for split in args.names]
    # Refuse to overwrite the results of an earlier run.
    assert all(not os.path.exists(path) for path in output_paths), output_paths

    output_files = []
    try:
        # Open one at a time so that a failure part-way through still lets
        # the ``finally`` clause close the handles opened so far.
        for path in output_paths:
            output_files.append(open(path, 'w'))

        with open(args.ranking) as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no ranking.
                    continue

                qid, pid, rank, *other = line.split('\t')
                qid = int(qid)

                # NOTE(review): a qid smaller than ``gap`` yields index -1,
                # which Python resolves to the LAST split -- confirm that
                # merged qids are always offset by at least one gap.
                split_output_file = output_files[qid // args.gap - 1]
                qid = qid % args.gap  # recover the split-local query id

                split_output_file.write('\t'.join([str(qid), pid, rank, *other]) + '\n')

            print(f.name)
    finally:
        # Close every output handle even when a malformed line raised.
        for output_file in output_files:
            output_file.close()

    print("#> Done!")


if __name__ == "__main__":
    # Seed retained from the original script; nothing visible in this file
    # draws random numbers, so it does not affect the split.
    random.seed(12345)

    # The description now matches what the script does (see module docstring).
    parser = ArgumentParser(description='Split the ranked lists after retrieval with a merged query set.')
    parser.add_argument('--ranking', dest='ranking', required=True,
                        help='path of the merged ranking file (tab-separated)')

    parser.add_argument('--names', dest='names', required=False, default=['train', 'dev', 'test'], type=str, nargs='+',
                        help='suffixes of the output files, one per query set')  # order matters!
    parser.add_argument('--gap', dest='gap', required=False, default=1_000_000_000, type=int,
                        help='qid offset separating the query sets')  # larger than any individual query set

    args = parser.parse_args()

    main(args)