File size: 1,714 Bytes
256a159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os.path as osp
from typing import Dict, List, Union

from mmengine.config import ConfigDict


def model_abbr_from_cfg(cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
    """Generate model abbreviation from the model's config.

    Args:
        cfg: A single model config, or a list/tuple of model configs.

    Returns:
        For a list/tuple, the abbreviations of its members joined by
        ``'_'``. For a single config, ``cfg['abbr']`` when that key is
        present; otherwise ``cfg['type']`` followed by the last two
        components of the resolved ``cfg['path']``, joined by ``'_'``.
    """
    if isinstance(cfg, (list, tuple)):
        return '_'.join(model_abbr_from_cfg(c) for c in cfg)
    if 'abbr' in cfg:
        return cfg['abbr']
    # realpath() yields the platform's native separator (a backslash on
    # Windows), so normalise it to '/' before picking the trailing
    # components. On POSIX osp.sep is already '/', making this a no-op.
    resolved = osp.realpath(cfg['path']).replace(osp.sep, '/')
    model_abbr = cfg['type'] + '_' + '_'.join(resolved.split('/')[-2:])
    # cfg['type'] may itself contain '/', which must not leak into a name
    # that is later used as a single path component.
    return model_abbr.replace('/', '_')


def dataset_abbr_from_cfg(cfg: ConfigDict) -> str:
    """Return the dataset abbreviation derived from the dataset's config.

    An explicit ``cfg['abbr']`` is returned as-is. Otherwise the result is
    ``cfg['path']``, followed by ``'_'`` and ``cfg['name']`` when a name is
    present, with every ``'/'`` replaced by ``'_'``.
    """
    if 'abbr' in cfg:
        return cfg['abbr']
    raw = (cfg['path'] + '_' + cfg['name']) if 'name' in cfg else cfg['path']
    return raw.replace('/', '_')


def task_abbr_from_cfg(task: Dict) -> str:
    """Return the task abbreviation derived from the task's config.

    The result is a bracketed, comma-separated list of
    ``<model_abbr>/<dataset_abbr>`` entries, one for every dataset in
    ``task['datasets'][i]`` paired with the i-th model of
    ``task['models']``.
    """
    entries = []
    for idx, model_cfg in enumerate(task['models']):
        for dataset_cfg in task['datasets'][idx]:
            entries.append(f'{model_abbr_from_cfg(model_cfg)}/'
                           f'{dataset_abbr_from_cfg(dataset_cfg)}')
    joined = ','.join(entries)
    return '[' + joined + ']'


def get_infer_output_path(model_cfg: ConfigDict,
                          dataset_cfg: ConfigDict,
                          root_path: str = None,
                          file_extension: str = 'json') -> str:
    """Build the output file path for a model/dataset pair.

    Args:
        model_cfg: Model config, abbreviated via ``model_abbr_from_cfg``.
        dataset_cfg: Dataset config, abbreviated via
            ``dataset_abbr_from_cfg``.
        root_path: Directory under which the path is built. It must be
            given explicitly; the ``None`` default is rejected.
        file_extension: Extension placed after the dataset abbreviation,
            without the leading dot.

    Returns:
        ``root_path`` joined with the model abbreviation and
        ``'<dataset_abbr>.<file_extension>'``.

    Raises:
        AssertionError: If ``root_path`` is ``None``.
    """
    # TODO: Rename this func
    if root_path is None:
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # kept so existing callers see the same failure.
        raise AssertionError('default root_path is not allowed any more')
    model_abbr = model_abbr_from_cfg(model_cfg)
    dataset_abbr = dataset_abbr_from_cfg(dataset_cfg)
    return osp.join(root_path, model_abbr, f'{dataset_abbr}.{file_extension}')