Spaces:
Runtime error
Runtime error
import collections | |
import functools | |
import os | |
import re | |
import yaml | |
class AttrDict(dict):
    """Dict as attribute trick: keys are also readable/writable as attributes.

    On construction, nested ``dict`` values are wrapped in ``AttrDict`` and a
    list/tuple whose first element is a ``dict`` becomes a list of
    ``AttrDict``. Empty lists/tuples and other values are kept unchanged.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Alias the instance namespace to the dict itself, so attribute access
        # and item access read and write the same storage.
        self.__dict__ = self
        for key, value in self.__dict__.items():
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, (list, tuple)):
                # Check emptiness first: ``value[0]`` raises IndexError on an
                # empty list/tuple.
                if value and isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]

    def yaml(self):
        """Convert object to a plain ``dict`` (e.g. for YAML dumping) and return it.

        Nested ``AttrDict`` values, and lists whose first element is an
        ``AttrDict``, are converted recursively; other values are copied by
        reference unchanged.
        """
        yaml_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                yaml_dict[key] = [item.yaml() for item in value]
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Print all variables as ``key: value`` lines, nesting children.

        Nested ``AttrDict`` values (and lists whose first element is an
        ``AttrDict``) are printed under a ``key:`` header with each of their
        lines prefixed by a space.
        """
        lines = []
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                children = [value]
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                # Each list item is printed like a nested AttrDict.
                children = value
            else:
                lines.append('{}: {}'.format(key, value))
                continue
            lines.append('{}:'.format(key))
            for child in children:
                for child_line in child.__repr__().split('\n'):
                    lines.append(' ' + child_line)
        return '\n'.join(lines)
class Config(AttrDict):
    r"""Configuration class. This should include every human specifiable
    hyperparameter values for your training.

    Defaults are assigned first; the YAML file ``filename`` is then merged on
    top of them with ``recursive_update``.

    Args:
        filename (str): Path to the YAML configuration file. Despite the
            ``None`` default, an existing file is required.
        verbose (bool): If True, print the merged configuration.
        is_train (bool): Sets ``phase`` to ``'train'`` (True) or ``'test'``.

    Raises:
        AssertionError: If ``filename`` is None or does not exist.
        OSError: If the file exists but cannot be opened or read.
    """

    def __init__(self, filename=None, verbose=False, is_train=True):
        super(Config, self).__init__()

        # Set default parameters.
        # Logging.
        large_number = 1000000000  # Very large default interval / limit.
        self.snapshot_save_iter = large_number
        self.snapshot_save_epoch = large_number
        self.snapshot_save_start_iter = 0
        self.snapshot_save_start_epoch = 0
        self.image_save_iter = large_number
        self.eval_epoch = large_number
        self.start_eval_epoch = large_number
        self.max_epoch = large_number
        self.max_iter = large_number
        self.logging_iter = 100
        self.image_to_tensorboard = False
        self.which_iter = None
        self.resume = True
        self.checkpoints_dir = 'NTED'
        self.name = 'nted_checkpoint.pt'
        self.phase = 'train' if is_train else 'test'

        # Networks.
        self.gen = AttrDict(type='generators.dummy')
        self.dis = AttrDict(type='discriminators.dummy')

        # Optimizers.
        self.gen_optimizer = AttrDict(type='adam',
                                      lr=0.0001,
                                      adam_beta1=0.0,
                                      adam_beta2=0.999,
                                      eps=1e-8,
                                      lr_policy=AttrDict(iteration_mode=False,
                                                         type='step',
                                                         step_size=large_number,
                                                         gamma=1))
        self.dis_optimizer = AttrDict(type='adam',
                                      lr=0.0001,
                                      adam_beta1=0.0,
                                      adam_beta2=0.999,
                                      eps=1e-8,
                                      lr_policy=AttrDict(iteration_mode=False,
                                                         type='step',
                                                         step_size=large_number,
                                                         gamma=1))

        # Data.
        self.data = AttrDict(name='dummy',
                             type='datasets.images',
                             num_workers=0)
        self.test_data = AttrDict(name='dummy',
                                  type='datasets.images',
                                  num_workers=0,
                                  test=AttrDict(is_lmdb=False,
                                                roots='',
                                                batch_size=1))
        self.trainer = AttrDict(
            image_to_tensorboard=False,
            hparam_to_tensorboard=False)

        # Cudnn.
        self.cudnn = AttrDict(deterministic=False,
                              benchmark=True)

        # Others.
        self.pretrained_weight = ''
        self.inference_args = AttrDict()

        # Update with given configurations.
        # Explicit check rather than ``assert`` so it still runs under
        # ``python -O``; AssertionError is kept so callers see the same
        # exception type. ``None`` is checked first because
        # ``os.path.exists(None)`` raises TypeError.
        if filename is None or not os.path.exists(filename):
            raise AssertionError('File {} not exist.'.format(filename))

        # Private SafeLoader subclass: registering the resolver on
        # ``yaml.SafeLoader`` itself would change every safe_load in the
        # process and append a duplicate resolver on each instantiation.
        class _ConfigLoader(yaml.SafeLoader):
            pass

        loader = _ConfigLoader
        # Extra float resolver; e.g. its second alternative matches
        # exponent-only literals such as ``1e-4``.
        loader.add_implicit_resolver(
            u'tag:yaml.org,2002:float',
            re.compile(u'''^(?:
            [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
            |[-+]?\\.(?:inf|Inf|INF)
            |\\.(?:nan|NaN|NAN))$''', re.X),
            list(u'-+0123456789.'))
        try:
            with open(filename, 'r') as f:
                cfg_dict = yaml.load(f, Loader=loader)
        except EnvironmentError:
            # Report the path, then re-raise: continuing would use an
            # unbound ``cfg_dict``.
            print('Please check the file with name of "%s"' % filename)
            raise
        # An empty YAML document loads as None; treat it as "no overrides".
        if cfg_dict is None:
            cfg_dict = {}
        recursive_update(self, cfg_dict)

        # Put common opts in both gen and dis (the same object is shared).
        if 'common' in cfg_dict:
            self.common = AttrDict(**(cfg_dict['common'] or {}))
            self.gen.common = self.common
            self.dis.common = self.common

        if verbose:
            print(' config '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))
def rsetattr(obj, attr, val):
    """Assign ``val`` to a possibly dotted attribute path on ``obj``.

    ``rsetattr(obj, 'a.b', v)`` behaves like ``obj.a.b = v``. All components
    except the last are resolved with ``rgetattr``; the last one is set with
    ``setattr``. Returns None.
    """
    parent_path, _, leaf = attr.rpartition('.')
    target = obj
    if parent_path:
        target = rgetattr(obj, parent_path)
    setattr(target, leaf, val)
def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path such as ``'a.b.c'`` on ``obj``.

    Each path component is fetched with ``getattr``. Any extra positional
    ``args`` (normally a single default) are forwarded to every ``getattr``
    call, so a missing component yields the default and the walk continues
    from it.
    """
    current = obj
    for name in attr.split('.'):
        current = getattr(current, name, *args)
    return current
def recursive_update(d, u):
    """Recursively update AttrDict ``d`` in place with mapping ``u``; return ``d``.

    Per key of ``u``:
      - a Mapping value is merged into ``d``'s existing nested ``AttrDict``.
        A fresh one is used if ``d`` has none or holds a non-dict value; an
        existing plain ``dict`` is wrapped so its keys survive the merge.
      - a non-empty list/tuple whose first element is a ``dict`` becomes a
        list of ``AttrDict``.
      - anything else (including empty lists/tuples) overwrites ``d[key]``.

    Args:
        d (AttrDict): Destination, modified in place.
        u (Mapping): Overrides, e.g. a dict loaded from YAML.

    Returns:
        AttrDict: ``d`` itself.
    """
    # Import the submodule explicitly; ``import collections`` alone does not
    # guarantee that ``collections.abc`` has been loaded.
    import collections.abc

    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            existing = d.get(key)
            if not isinstance(existing, AttrDict):
                # ``.get``/``.__dict__`` below require an AttrDict; a scalar or
                # None default would otherwise crash the merge.
                if isinstance(existing, dict):
                    existing = AttrDict(existing)
                else:
                    existing = AttrDict()
            d.__dict__[key] = recursive_update(existing, value)
        elif (isinstance(value, (list, tuple)) and value
              and isinstance(value[0], dict)):
            d.__dict__[key] = [AttrDict(item) for item in value]
        else:
            d.__dict__[key] = value
    return d