# GraCo / isegm/utils/serialization.py
# Added by zhaoyian01 in commit 6d1366a ("Add application file"); original size 3.26 kB.
from functools import wraps
from copy import deepcopy
import inspect
import torch.nn as nn
def serialize(init):
    """Decorate an ``__init__`` so each instance records how it was constructed.

    Before the wrapped ``__init__`` runs, the instance receives a ``_config``
    attribute of the form::

        {'class': '<dotted class path from get_classname>',
         'params': {name: {'type': 'builtin' or 'class',
                           'value': <argument value, or dotted class path>,
                           'specified': <True if the caller passed it>}}}

    Caller-supplied arguments come first: keyword arguments, then positional
    arguments bound (in order) to the parameter names after the first one
    (``self``). Parameters the caller omitted are then filled in from the
    defaults returned by ``get_default_params`` and marked as not specified.
    Class objects are stored by their dotted path rather than by reference.
    """
    # Names of the parameters following ``self``; positional args bind to these.
    positional_names = list(inspect.signature(init).parameters)[1:]

    @wraps(init)
    def new_init(self, *args, **kwargs):
        # Keyword arguments are deep-copied so later caller-side mutation cannot
        # change the snapshot; positional values are stored by reference.
        call_args = deepcopy(kwargs)
        call_args.update(zip(positional_names, args))

        class_path = get_classname(self.__class__)
        given_names = set(call_args)

        # Fill in every defaulted constructor parameter the caller left out.
        for default_name, default_param in get_default_params(self.__class__).items():
            call_args.setdefault(default_name, default_param.default)

        described = {}
        for arg_name, arg_value in call_args.items():
            if inspect.isclass(arg_value):
                entry_type, entry_value = 'class', get_classname(arg_value)
            else:
                entry_type, entry_value = 'builtin', arg_value
            described[arg_name] = {
                'type': entry_type,
                'value': entry_value,
                'specified': arg_name in given_names,
            }

        setattr(self, '_config', {'class': class_path, 'params': described})
        init(self, *args, **kwargs)

    return new_init
def load_model(config, eval_ritm, **kwargs):
    """Rebuild a model instance from a config recorded by ``@serialize``.

    Args:
        config: dict with ``'class'`` (dotted class path) and ``'params'``
            (name -> ``{'type', 'value', 'specified'}``), i.e. the ``_config``
            structure written by ``serialize``.
        eval_ritm: if true, force ``use_rgb_conv=True`` in the constructor
            arguments (this overrides any value from ``config`` or ``kwargs``).
        **kwargs: extra constructor arguments; they override values from ``config``.

    Returns:
        A new instance of the class named by ``config['class']``.

    Raises:
        ValueError: if a parameter marked as specified in ``config`` is not a
            defaulted ``__init__`` parameter of the current class.
    """
    model_class = get_class_from_str(config['class'])
    model_default_params = get_default_params(model_class)

    model_args = {}
    for pname, param in config['params'].items():
        specified = param['specified']
        known = pname in model_default_params

        # An unspecified (default-valued) parameter the class no longer has is
        # dropped. This is checked before resolving class-typed values so that
        # a stale entry naming a since-removed class does not break loading.
        if not known and not specified:
            continue
        if not known:
            # Explicit raise instead of ``assert``: the check must not vanish
            # under ``python -O``.
            raise ValueError(
                f"Parameter {pname!r} is set in the config but is not a "
                f"defaulted __init__ parameter of {config['class']}"
            )

        value = param['value']
        if param['type'] == 'class':
            value = get_class_from_str(value)

        # Leave out values that merely repeat the class's current default.
        if not specified and model_default_params[pname].default == value:
            continue
        model_args[pname] = value

    model_args.update(kwargs)

    # This ugly hardcode is only to support evaluation of RITM models.
    # Ignore it if you are evaluating SimpleClick models.
    if eval_ritm:
        model_args['use_rgb_conv'] = True

    return model_class(**model_args)
def get_config_repr(config):
    """Format a serialized model config as a human-readable multi-line string.

    The first line is ``Model: <class path>``. Each following line shows one
    parameter as ``name = value`` in fixed-width columns (22 and 12 chars).
    Class-typed values are shortened to their last dotted component, and
    parameters not passed explicitly get a `` (default)`` suffix. The result
    always ends with a newline.
    """
    lines = [f'Model: {config["class"]}']
    for pname, param in config['params'].items():
        value = param['value']
        if param['type'] == 'class':
            value = value.split('.')[-1]
        line = f'{pname:<22} = {str(value):<12}'
        if not param['specified']:
            line += ' (default)'
        lines.append(line)
    # Join once instead of repeatedly concatenating onto a growing string.
    return '\n'.join(lines) + '\n'
def get_default_params(some_class):
    """Collect constructor parameters that declare a default, across the MRO.

    Walks ``some_class.mro()``, skipping ``nn.Module`` and ``object``, and
    gathers every ``__init__`` parameter that has a default value. When several
    classes declare the same name, the first one in MRO order (the most derived)
    wins.

    Returns:
        dict mapping parameter name -> ``inspect.Parameter``.
    """
    params = {}
    for mclass in some_class.mro():
        if mclass is nn.Module or mclass is object:
            continue

        for pname, param in inspect.signature(mclass.__init__).parameters.items():
            # Identity check against the sentinel: ``!=`` would compare the
            # default value itself, which raises for array/tensor defaults
            # (elementwise comparison has no single truth value).
            if param.default is not inspect.Parameter.empty and pname not in params:
                params[pname] = param

    return params
def get_classname(cls):
    """Return the dotted path ``"<module>.<qualname>"`` identifying ``cls``.

    The module prefix is dropped only when ``cls.__module__`` is ``None`` or
    equals the Python 2 name ``"__builtin__"``. Under Python 3, built-in types
    report ``"builtins"`` and therefore keep their prefix (e.g. ``"builtins.int"``).
    """
    prefix, qualname = cls.__module__, cls.__qualname__
    keep_prefix = not (prefix is None or prefix == "__builtin__")
    return prefix + "." + qualname if keep_prefix else qualname
def get_class_from_str(class_str):
    """Resolve a dotted path (as produced by ``get_classname``) to the object it names.

    The longest prefix of ``class_str`` that imports as a module is imported,
    and the remaining components are looked up as attributes. This also
    resolves nested classes, whose ``__qualname__`` contains dots
    (e.g. ``"pkg.mod.Outer.Inner"``).

    Raises:
        ValueError: if ``class_str`` has no module part (contains no dot).
        ModuleNotFoundError: if no prefix of the path is an importable module,
            or if a module on the path fails to import one of its own dependencies.
        AttributeError: if the attribute path does not exist in the module.
    """
    import importlib

    components = class_str.split('.')
    if len(components) < 2:
        raise ValueError(f'Expected a dotted "module.ClassName" path, got {class_str!r}')

    last_error = None
    # Try the longest candidate module path first, then progressively shorter ones.
    for split_at in range(len(components) - 1, 0, -1):
        module_name = '.'.join(components[:split_at])
        try:
            obj = importlib.import_module(module_name)
        except ModuleNotFoundError as err:
            # Only fall back when the missing module is this candidate path (or
            # one of its parents); a missing dependency imported *inside* an
            # existing module is a real error and must propagate.
            missing = err.name
            on_our_path = missing is not None and (
                module_name == missing or module_name.startswith(missing + '.')
            )
            if not on_our_path:
                raise
            last_error = err
            continue

        for attr in components[split_at:]:
            obj = getattr(obj, attr)
        return obj

    # Every candidate prefix failed to import; surface the last failure.
    raise last_error