TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame
1.67 kB
from typing import Any, Dict, List, Union
from datasets import Dataset, DatasetDict
from mmengine.config import Config
from opencompass.registry import TASKS
def get_type_from_cfg(cfg: Union[Config, Dict]) -> Any:
    """Get the object type given MMEngine's Config.

    It loads the "type" field and returns the corresponding object type.

    Args:
        cfg: A config (or plain dict) that is indexed with the ``'type'``
            key.

    Returns:
        ``cfg['type']`` unchanged when it is not a string; otherwise the
        result of looking that string up with ``TASKS.get``.

    Raises:
        KeyError: If ``cfg`` is a plain dict without a ``'type'`` entry
            (whatever ``cfg``'s own item lookup raises is propagated).
    """
    # Use a dedicated local name so the builtin ``type`` is not shadowed.
    obj_type = cfg['type']
    if isinstance(obj_type, str):
        # FIXME: This has nothing to do with any specific registry, to be fixed
        # in MMEngine
        obj_type = TASKS.get(obj_type)
    return obj_type
def _check_type_list(obj, typelist: List):
for _type in typelist:
if _type is None:
if obj is None:
return obj
elif isinstance(obj, _type):
return obj
raise TypeError(
f'Expected an object in {[_.__name__ if _ is not None else None for _ in typelist]} type, but got {obj}' # noqa
)
def _check_dataset(obj) -> Union[Dataset, DatasetDict]:
    """Validate that ``obj`` is a ``Dataset`` or a ``DatasetDict``.

    Args:
        obj: The object to validate.

    Returns:
        ``obj`` itself, unchanged.

    Raises:
        TypeError: If ``obj`` is neither a ``Dataset`` nor a
            ``DatasetDict``.
    """
    if not isinstance(obj, (Dataset, DatasetDict)):
        raise TypeError(
            f'Expected a datasets.Dataset or a datasets.DatasetDict object, but got {obj}'  # noqa
        )
    return obj
def _check_list(obj) -> List:
if isinstance(obj, List):
return obj
else:
raise TypeError(f'Expected a List object, but got {obj}')
def _check_str(obj) -> str:
if isinstance(obj, str):
return obj
else:
raise TypeError(f'Expected a str object, but got {obj}')
def _check_dict(obj) -> Dict:
if isinstance(obj, Dict):
return obj
else:
raise TypeError(f'Expected a Dict object, but got {obj}')