import collections
import collections.abc
import functools
import os
import re

import yaml


class AttrDict(dict):
    """Dictionary whose keys are also readable/writable as attributes.

    The instance's ``__dict__`` *is* the dictionary itself, so ``d.key`` and
    ``d['key']`` refer to the same storage. On construction, nested ``dict``
    values are wrapped as ``AttrDict``, and lists/tuples whose first element
    is a ``dict`` become lists of ``AttrDict``.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Alias attribute storage to the dict itself so attribute and item
        # access share the same data.
        self.__dict__ = self
        for key, value in self.__dict__.items():
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, (list, tuple)):
                # ``value and ...`` guards empty sequences, which previously
                # raised IndexError on ``value[0]``.
                if value and isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]

    def yaml(self):
        """Return a nested plain-``dict`` representation for YAML dumping.

        ``AttrDict`` values are converted recursively, as are lists whose
        first element is an ``AttrDict``. Every other value is returned as-is
        (not copied).
        """
        yaml_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                yaml_dict[key] = [item.yaml() for item in value]
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Return an indented, YAML-like listing of all entries."""
        lines = []
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                lines.append('{}:'.format(key))
                lines.extend(' ' + line for line in repr(value).split('\n'))
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                lines.append('{}:'.format(key))
                for child in value:
                    # Each list element is rendered like a nested AttrDict.
                    lines.extend(
                        ' ' + line for line in repr(child).split('\n'))
            else:
                lines.append('{}: {}'.format(key, value))
        return '\n'.join(lines)


class _ConfigLoader(yaml.SafeLoader):
    """``yaml.SafeLoader`` that also resolves exponent floats such as ``1e-4``.

    PyYAML's default float resolver does not match exponent-only forms like
    ``1e-4`` (they load as strings). The extra resolver is registered on this
    private subclass, once at import time, so the global ``yaml.SafeLoader``
    used elsewhere in the process is left untouched.
    """


_ConfigLoader.add_implicit_resolver(
    u'tag:yaml.org,2002:float',
    re.compile(u'''^(?:
         [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |\\.[0-9_]+(?:[eE][-+][0-9]+)?
        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
        |[-+]?\\.(?:inf|Inf|INF)
        |\\.(?:nan|NaN|NAN))$''', re.X),
    list(u'-+0123456789.'))


class Config(AttrDict):
    r"""Configuration class.

    Holds every human-specifiable hyperparameter for training. Built-in
    defaults are set first and are then overridden by the YAML file
    ``filename``, if one is given.

    Args:
        filename (str or None): Path to a YAML configuration file. When
            ``None``, only the built-in defaults are used.
        verbose (bool): If True, print the resulting configuration.
        is_train (bool): Sets ``phase`` to ``'train'`` (True) or ``'test'``.

    Raises:
        AssertionError: If ``filename`` is given but does not exist.
        OSError: If the file cannot be opened or read (re-raised after a
            message is printed).
    """

    def __init__(self, filename=None, verbose=False, is_train=True):
        super(Config, self).__init__()

        # Set default parameters.

        # Logging.
        # Large sentinel; presumably means "never" / "unbounded" for the
        # interval and limit settings below.
        large_number = 1000000000
        self.snapshot_save_iter = large_number
        self.snapshot_save_epoch = large_number
        self.snapshot_save_start_iter = 0
        self.snapshot_save_start_epoch = 0
        self.image_save_iter = large_number
        self.eval_epoch = large_number
        self.start_eval_epoch = large_number
        self.max_epoch = large_number
        self.max_iter = large_number
        self.logging_iter = 100
        self.image_to_tensorboard = False
        self.which_iter = None
        self.resume = True
        self.checkpoints_dir = 'NTED'
        self.name = 'nted_checkpoint.pt'
        self.phase = 'train' if is_train else 'test'

        # Networks.
        self.gen = AttrDict(type='generators.dummy')
        self.dis = AttrDict(type='discriminators.dummy')

        # Optimizers.
        self.gen_optimizer = AttrDict(
            type='adam', lr=0.0001, adam_beta1=0.0, adam_beta2=0.999,
            eps=1e-8,
            lr_policy=AttrDict(iteration_mode=False, type='step',
                               step_size=large_number, gamma=1))
        self.dis_optimizer = AttrDict(
            type='adam', lr=0.0001, adam_beta1=0.0, adam_beta2=0.999,
            eps=1e-8,
            lr_policy=AttrDict(iteration_mode=False, type='step',
                               step_size=large_number, gamma=1))

        # Data.
        self.data = AttrDict(name='dummy', type='datasets.images',
                             num_workers=0)
        self.test_data = AttrDict(
            name='dummy', type='datasets.images', num_workers=0,
            test=AttrDict(is_lmdb=False, roots='', batch_size=1))
        self.trainer = AttrDict(
            image_to_tensorboard=False,
            hparam_to_tensorboard=False)

        # Cudnn.
        self.cudnn = AttrDict(deterministic=False, benchmark=True)

        # Others.
        self.pretrained_weight = ''
        self.inference_args = AttrDict()

        # Update with given configurations.
        if filename is not None:
            assert os.path.exists(filename), \
                'File {} not exist.'.format(filename)
            try:
                with open(filename, 'r') as f:
                    cfg_dict = yaml.load(f, Loader=_ConfigLoader)
            except EnvironmentError:
                print('Please check the file with name of "%s"' % filename)
                # Re-raise: continuing would only fail later on an unbound
                # ``cfg_dict`` and hide the real I/O error.
                raise

            # An empty YAML document loads as None; treat it as "no
            # overrides" instead of crashing below.
            if cfg_dict is None:
                cfg_dict = {}
            recursive_update(self, cfg_dict)

            # Put common opts in both gen and dis.
            if 'common' in cfg_dict:
                self.common = AttrDict(**cfg_dict['common'])
                self.gen.common = self.common
                self.dis.common = self.common

        if verbose:
            print(' config '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))


def rsetattr(obj, attr, val):
    """Set a possibly dotted attribute path, e.g. ``rsetattr(c, 'gen.lr', 1)``.

    Every component but the last is resolved with :func:`rgetattr`; only the
    final component is assigned. Returns ``None`` (the result of setattr).
    """
    pre, _, post = attr.rpartition('.')
    return setattr(rgetattr(obj, pre) if pre else obj, post, val)


def rgetattr(obj, attr, *args):
    """Get a possibly dotted attribute path, e.g. ``rgetattr(c, 'gen.lr')``.

    An optional default passed in ``*args`` is forwarded to every
    ``getattr`` step, so a missing component yields the default (which is
    then used for any remaining lookups).
    """
    def _getattr(obj, attr):
        r"""Get one attribute, honoring the optional default."""
        return getattr(obj, attr, *args)
    return functools.reduce(_getattr, [obj] + attr.split('.'))


def recursive_update(d, u):
    """Recursively merge mapping ``u`` into AttrDict ``d`` in place; return ``d``.

    Nested mappings are merged, with ``AttrDict`` containers created as
    needed. Lists/tuples whose first element is a ``dict`` become lists of
    ``AttrDict``. Every other value overwrites the existing entry.
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            existing = d.get(key)
            if not isinstance(existing, AttrDict):
                # Missing key, scalar default (e.g. None) or plain dict:
                # start from an AttrDict so the recursive merge can assign
                # through ``__dict__``.
                if not isinstance(existing, collections.abc.Mapping):
                    existing = {}
                existing = AttrDict(existing)
            d.__dict__[key] = recursive_update(existing, value)
        elif (isinstance(value, (list, tuple)) and value
              and isinstance(value[0], dict)):
            d.__dict__[key] = [AttrDict(item) for item in value]
        else:
            d.__dict__[key] = value
    return d