|
import os.path as osp |
|
from typing import Dict, List, Union |
|
|
|
from mmengine.config import ConfigDict |
|
|
|
|
|
def model_abbr_from_cfg(cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
    """Generate a model abbreviation from the model's config.

    Args:
        cfg: A single model config, or a list/tuple of model configs.

    Returns:
        str: For a list or tuple, the abbreviations of its items joined by
        ``'_'``. For a single config, ``cfg['abbr']`` when that key is
        present; otherwise ``cfg['type']`` followed by the last two
        components of the resolved ``cfg['path']``, joined by ``'_'``, with
        any remaining ``'/'`` replaced by ``'_'``.
    """
    if isinstance(cfg, (list, tuple)):
        return '_'.join(model_abbr_from_cfg(c) for c in cfg)
    if 'abbr' in cfg:
        return cfg['abbr']
    # ``realpath`` returns the path with the platform's native separator.
    # Normalise it to '/' first so that the "last two components" rule also
    # holds where the separator is not '/' (a no-op on POSIX).
    resolved = osp.realpath(cfg['path']).replace(osp.sep, '/')
    path_tail = resolved.split('/')[-2:]
    model_abbr = cfg['type'] + '_' + '_'.join(path_tail)
    # ``cfg['type']`` may itself contain '/', so flatten once more.
    return model_abbr.replace('/', '_')
|
|
|
|
|
def dataset_abbr_from_cfg(cfg: ConfigDict) -> str:
    """Return the dataset abbreviation derived from the dataset's config.

    Args:
        cfg: The dataset config.

    Returns:
        str: ``cfg['abbr']`` when that key is present. Otherwise
        ``cfg['path']``, extended with ``'_' + cfg['name']`` when a
        ``'name'`` key exists, with every ``'/'`` replaced by ``'_'``.
    """
    if 'abbr' in cfg:
        return cfg['abbr']
    base = cfg['path']
    # The name suffix is optional; an empty suffix leaves the path as is.
    suffix = ('_' + cfg['name']) if 'name' in cfg else ''
    return (base + suffix).replace('/', '_')
|
|
|
|
|
def task_abbr_from_cfg(task: Dict) -> str:
    """Return the task abbreviation derived from the task's config.

    Args:
        task: Mapping with a ``'models'`` sequence and a ``'datasets'``
            sequence, where ``task['datasets'][i]`` holds the datasets
            paired with ``task['models'][i]``.

    Returns:
        str: Comma-separated ``'<model_abbr>/<dataset_abbr>'`` entries, one
        per model/dataset pair, wrapped in square brackets.
    """
    entries = []
    for idx, model in enumerate(task['models']):
        # Datasets are looked up by the model's position in the list.
        for dataset in task['datasets'][idx]:
            entries.append(f'{model_abbr_from_cfg(model)}/'
                           f'{dataset_abbr_from_cfg(dataset)}')
    return '[' + ','.join(entries) + ']'
|
|
|
|
|
def get_infer_output_path(model_cfg: ConfigDict,
                          dataset_cfg: ConfigDict,
                          root_path: str = None,
                          file_extension: str = 'json') -> str:
    """Build the output file path for a model/dataset pair.

    Args:
        model_cfg: Model config, abbreviated via ``model_abbr_from_cfg``.
        dataset_cfg: Dataset config, abbreviated via
            ``dataset_abbr_from_cfg``.
        root_path: Root directory of the outputs. Must be provided even
            though the signature defaults it to ``None``.
        file_extension: Extension of the output file, without the leading
            dot. Defaults to ``'json'``.

    Returns:
        str: ``<root_path>/<model_abbr>/<dataset_abbr>.<file_extension>``,
        joined with ``os.path.join``.

    Raises:
        AssertionError: If ``root_path`` is ``None``.
    """
    # Raise explicitly instead of using ``assert``: asserts are stripped
    # under ``python -O``, which would let ``None`` reach ``osp.join`` and
    # fail with an unrelated TypeError. The exception type and message are
    # kept so existing callers see the same error.
    if root_path is None:
        raise AssertionError('default root_path is not allowed any more')
    model_abbr = model_abbr_from_cfg(model_cfg)
    dataset_abbr = dataset_abbr_from_cfg(dataset_cfg)
    return osp.join(root_path, model_abbr, f'{dataset_abbr}.{file_extension}')
|
|