File size: 832 Bytes
37aeb5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import copy
from pathlib import Path

from omegaconf import DictConfig, OmegaConf


def parse_structured(fields, cfg) -> DictConfig:
    """Build ``fields`` from the entries of ``cfg`` and wrap it with OmegaConf.

    Args:
        fields: Callable invoked with the entries of ``cfg`` as keyword
            arguments (presumably a dataclass or similar schema class —
            confirm against callers).
        cfg: Mapping whose key/value pairs become keyword arguments to
            ``fields``.

    Returns:
        The result of ``OmegaConf.structured`` applied to the constructed
        ``fields`` instance.
    """
    instance = fields(**cfg)
    return OmegaConf.structured(instance)


def load_config(fields, config, extras=None):
    """Load a config, apply optional CLI overrides, and return a structured config.

    Args:
        fields: Callable forwarded to ``parse_structured``; it is called with
            the resolved config entries as keyword arguments.
        config: One of: a path to a config file (``str`` or
            ``pathlib.Path``, read with ``OmegaConf.load``), a plain ``dict``,
            or an existing ``DictConfig``. A ``DictConfig`` is deep-copied
            first, so resolving interpolations never mutates the caller's
            object.
        extras: Optional list of ``key=value`` override strings passed to
            ``OmegaConf.from_cli`` and merged on top of ``config``. ``None``
            and an empty list both mean "no overrides".

    Returns:
        The value returned by ``parse_structured(fields, cfg)``.

    Raises:
        NotImplementedError: If ``config`` is of an unsupported type.
        TypeError: If the final config is not a ``DictConfig`` (for example,
            when the loaded file's top level is a list).
    """
    # An empty override list (e.g. from argparse.parse_known_args with no
    # unknown args) carries no overrides: no warning, nothing to merge.
    has_extras = bool(extras)
    if has_extras:
        print("Warning! extra parameter in cli is not verified, may cause errors.")

    if isinstance(config, (str, Path)):
        cfg = OmegaConf.load(config)
    elif isinstance(config, dict):
        cfg = OmegaConf.create(config)
    elif isinstance(config, DictConfig):
        # OmegaConf.resolve works in place. Copy so the caller's config keeps
        # its interpolations; the merge path below already yields a new object.
        cfg = copy.deepcopy(config)
    else:
        raise NotImplementedError(f"Unsupported config type {type(config)}")

    if has_extras:
        cli_conf = OmegaConf.from_cli(extras)
        cfg = OmegaConf.merge(cfg, cli_conf)

    OmegaConf.resolve(cfg)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(cfg, DictConfig):
        raise TypeError(f"Expected a DictConfig after loading, got {type(cfg)}")
    return parse_structured(fields, cfg)