File size: 1,667 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from typing import Any, Dict, List, Union
from datasets import Dataset, DatasetDict
from mmengine.config import Config
from opencompass.registry import TASKS
def get_type_from_cfg(cfg: Union[Config, Dict]) -> Any:
    """Resolve the object type named by the ``type`` field of a config.

    Args:
        cfg: An MMEngine ``Config`` or a plain dict holding a ``'type'``
            entry.

    Returns:
        The ``'type'`` entry itself when it is not a string; otherwise the
        result of looking that string up in the ``TASKS`` registry.
    """
    resolved = cfg['type']
    if not isinstance(resolved, str):
        # Already an object (e.g. a class), nothing to look up.
        return resolved
    # FIXME: This has nothing to do with any specific registry, to be fixed
    # in MMEngine
    return TASKS.get(resolved)
def _check_type_list(obj, typelist: List):
for _type in typelist:
if _type is None:
if obj is None:
return obj
elif isinstance(obj, _type):
return obj
raise TypeError(
f'Expected an object in {[_.__name__ if _ is not None else None for _ in typelist]} type, but got {obj}' # noqa
)
def _check_dataset(obj) -> Union[Dataset, DatasetDict]:
    """Return ``obj`` if it is a ``Dataset`` or a ``DatasetDict``.

    Args:
        obj: The object to validate.

    Returns:
        ``obj`` itself, unchanged.

    Raises:
        TypeError: If ``obj`` is neither a ``Dataset`` nor a
            ``DatasetDict``.
    """
    if not isinstance(obj, (Dataset, DatasetDict)):
        raise TypeError(
            f'Expected a datasets.Dataset or a datasets.DatasetDict object, but got {obj}'  # noqa
        )
    return obj
def _check_list(obj) -> List:
if isinstance(obj, List):
return obj
else:
raise TypeError(f'Expected a List object, but got {obj}')
def _check_str(obj) -> str:
if isinstance(obj, str):
return obj
else:
raise TypeError(f'Expected a str object, but got {obj}')
def _check_dict(obj) -> Dict:
if isinstance(obj, Dict):
return obj
else:
raise TypeError(f'Expected a Dict object, but got {obj}')
|