File size: 3,604 Bytes
3bbb319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp

from mmcv import Config


def parse_args():
    """Parse the command-line options of the benchmark-script converter.

    Returns:
        argparse.Namespace: Namespace with the attributes ``config``,
            ``port``, ``work_dir``, ``run`` and ``out``.
    """
    cli = argparse.ArgumentParser(
        description='Convert benchmark model list to script')
    cli.add_argument('config', help='test config file path')

    # Optional flags, registered in the order they appear in ``--help``.
    optional_flags = (
        ('--port', dict(type=int, default=29666, help='dist port')),
        ('--work-dir',
         dict(default='tools/batch_test', help='the dir to save metric')),
        ('--run', dict(action='store_true', help='run script directly')),
        ('--out',
         dict(type=str, help='path to save model benchmark script')),
    )
    for flag, options in optional_flags:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def process_model_info(model_info, work_dir):
    """Normalize one benchmark entry into the fields needed for a test job.

    Args:
        model_info (dict): Entry providing ``config`` (config file path),
            ``checkpoint`` (checkpoint file name) and ``eval`` (a single
            metric name, or a list/tuple of metric names).
        work_dir (str): Root output directory. A sub-directory named after
            the config file (without its extension) is appended to it.

    Returns:
        dict: The keys ``config``, ``job_name``, ``work_dir``,
            ``checkpoint`` and ``eval``, where ``eval`` is the metric names
            joined by single spaces.
    """
    config = model_info['config'].strip()
    # The config file's base name doubles as job name and output sub-dir.
    fname, _ = osp.splitext(osp.basename(config))
    checkpoint = model_info['checkpoint'].strip()

    # Accept tuples as well as lists: a tuple used to be wrapped as a single
    # element, which made the join below raise TypeError.
    metrics = model_info['eval']
    if not isinstance(metrics, (list, tuple)):
        metrics = [metrics]

    return dict(
        config=config,
        job_name=fname,
        work_dir=osp.join(work_dir, fname),
        checkpoint=checkpoint,
        eval=' '.join(metrics))


def create_test_bash_info(commands, model_test_dict, port, script_name,
                          partition):
    """Append the shell lines that test one model to ``commands``.

    Three strings are appended in place: an ``echo`` of the config path, a
    newline, and the launch command built around ``script_name`` (ending in
    ``&``). Nothing is returned.

    Args:
        commands (list[str]): Script fragments collected so far; mutated.
        model_test_dict (dict): Must provide ``config``, ``job_name``,
            ``checkpoint``, ``work_dir`` and ``eval``.
        port: Value written to ``dist_params.port`` for this job.
        script_name: Launcher script invoked by the command.
        partition: First positional argument handed to the launcher.
    """
    cfg_path = model_test_dict['config']
    job = model_test_dict['job_name']
    ckpt = model_test_dict['checkpoint']
    out_dir = model_test_dict['work_dir']
    metrics = model_test_dict['eval']

    commands.extend((f' \necho \'{cfg_path}\' &', '\n'))

    # Every fragment carries its own trailing space, so a plain
    # concatenation reproduces the exact spacing of the final command.
    launch_parts = (
        f'GPUS=8  GPUS_PER_NODE=8  CPUS_PER_TASK=2 {script_name} ',
        f'{partition} ',
        f'{job} ',
        f'{cfg_path} ',
        # $CHECKPOINT_DIR is emitted literally as a shell variable reference.
        f'$CHECKPOINT_DIR/{ckpt} ',
        f'--work-dir {out_dir} ',
        f'--eval {metrics} ',
        f'--cfg-option dist_params.port={port} ',
        ' &',
    )
    commands.append(''.join(launch_parts))


def main():
    """Build the batch-test shell script, then save and/or run it.

    The script starts by assigning ``PARTITION`` from ``$1`` and
    ``CHECKPOINT_DIR`` from ``$2``. Every model entry found in the config
    file then contributes one test command, each with its own port counted
    up from ``--port``. At least one of ``--out`` / ``--run`` is required,
    and ``--out`` must name a ``.sh`` file.
    """
    args = parse_args()
    out_path = args.out
    if out_path:
        out_suffix = out_path.rsplit('.', 1)[-1]
        assert out_path.endswith('.sh'), \
            f'Expected out file path suffix is .sh, but get .{out_suffix}'
    assert out_path or args.run, \
        ('Please specify at least one operation (save/run/ the '
         'script) with the argument "--out" or "--run"')

    # Script header: bind the two positional shell arguments to names.
    commands = ['PARTITION=$1 ', '\n', 'CHECKPOINT_DIR=$2 ', '\n']

    script_name = osp.join('tools', 'slurm_test.sh')
    next_port = args.port
    cfg = Config.fromfile(args.config)

    for model_key in cfg:
        entries = cfg[model_key]
        # A key may hold a single entry or a list of entries.
        entries = entries if isinstance(entries, list) else [entries]
        for entry in entries:
            print('processing: ', entry['config'])
            create_test_bash_info(
                commands, process_model_info(entry, args.work_dir),
                next_port, script_name, '$PARTITION')
            # Give every job a distinct port.
            next_port += 1

    command_str = ''.join(commands)
    if out_path:
        with open(out_path, 'w') as f:
            f.write(command_str)
    if args.run:
        os.system(command_str)


# Run the converter only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()