File size: 3,262 Bytes
6d1366a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from functools import wraps
from copy import deepcopy
import inspect
import torch.nn as nn


def serialize(init):
    """Wrap an ``__init__`` so every instance records the arguments it was built with.

    Before delegating to ``init``, the wrapper stores ``self._config``:

    * ``'class'``  -- dotted class path of ``self`` (via ``get_classname``).
    * ``'params'`` -- one entry per argument, covering the keyword arguments,
      the positional arguments (matched to their parameter names), and every
      defaulted parameter found by ``get_default_params`` that the caller did
      not pass. Each entry is ``{'type', 'value', 'specified'}``:
      ``'type'`` is ``'class'`` for class objects (whose ``'value'`` is then
      their dotted name) and ``'builtin'`` otherwise; ``'specified'`` is true
      only for arguments the caller actually supplied.

    Keyword arguments are deep-copied before being recorded. Positional
    arguments and default values are stored as the same objects.
    """
    # Skip the leading ``self`` so positional args line up with their names.
    positional_names = list(inspect.signature(init).parameters)[1:]

    @wraps(init)
    def new_init(self, *args, **kwargs):
        given = deepcopy(kwargs)
        given.update(zip(positional_names, args))

        class_name = get_classname(self.__class__)
        explicit = set(given)

        # Fill in any defaulted parameter the caller left out.
        for pname, param in get_default_params(self.__class__).items():
            given.setdefault(pname, param.default)

        described = {}
        for pname, value in given.items():
            is_class = inspect.isclass(value)
            described[pname] = {
                'type': 'class' if is_class else 'builtin',
                'value': get_classname(value) if is_class else value,
                'specified': pname in explicit,
            }

        self._config = {'class': class_name, 'params': described}
        init(self, *args, **kwargs)

    return new_init


def load_model(config, eval_ritm=False, **kwargs):
    """Instantiate a model from a config dict produced by ``serialize``.

    Explicitly specified entries are always passed to the constructor. An
    entry that was filled in from a default at save time (``specified`` is
    false) is dropped when the class no longer has a default for it, or when
    the stored value still equals the class's current default.

    Args:
        config: Dict with ``'class'`` (dotted class path) and ``'params'``
            (name -> ``{'type', 'value', 'specified'}``), as written by
            ``serialize``.
        eval_ritm: If true, force ``use_rgb_conv=True`` after applying
            ``kwargs``. Defaults to False.
        **kwargs: Extra constructor arguments; they override values taken
            from ``config``.

    Returns:
        A new instance of the class named by ``config['class']``.

    Raises:
        AssertionError: If an explicitly specified parameter in ``config`` is
            not accepted by name by any ``__init__`` in the class's MRO.
    """
    model_class = get_class_from_str(config['class'])
    model_default_params = get_default_params(model_class)
    # Includes required (no-default) parameters, which ``serialize`` records
    # as specified but ``get_default_params`` does not list.
    accepted_params = _get_init_param_names(model_class)

    model_args = {}
    for pname, param in config['params'].items():
        value = param['value']
        if param['type'] == 'class':
            value = get_class_from_str(value)

        if not param['specified']:
            # Default-filled value: skip it if the parameter lost its default,
            # or if the default has not changed since the config was saved.
            if (pname not in model_default_params
                    or model_default_params[pname].default == value):
                continue
        elif pname not in model_default_params and pname not in accepted_params:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; AssertionError is kept for compatibility with the
            # previous assert-based check.
            raise AssertionError(
                f"Parameter '{pname}' from the config is not accepted by "
                f"{config['class']}.__init__"
            )
        model_args[pname] = value
    model_args.update(kwargs)

    # This ugly hardcode is only to support evaluation for RITM models.
    # Ignore it if you are evaluating SimpleClick models.
    if eval_ritm:
        model_args['use_rgb_conv'] = True

    return model_class(**model_args)


def _get_init_param_names(some_class):
    """Return every keyword-passable ``__init__`` parameter name across the MRO.

    Walks the same classes as ``get_default_params`` (skipping ``nn.Module``
    and ``object``) but also keeps parameters that have no default.
    """
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                     inspect.Parameter.KEYWORD_ONLY)
    names = set()
    for mclass in some_class.mro():
        if mclass is nn.Module or mclass is object:
            continue
        for pname, param in inspect.signature(mclass.__init__).parameters.items():
            if param.kind in keyword_kinds:
                names.add(pname)
    return names


def get_config_repr(config):
    """Render a ``serialize``-style config as a human-readable, multi-line string.

    The first line names the model class. Each following line shows one
    parameter as ``name = value`` in fixed-width columns. Class-typed values
    are shortened to their last dotted component, and entries whose
    ``'specified'`` flag is false are tagged ``(default)``.
    """
    lines = [f'Model: {config["class"]}\n']
    for pname, param in config['params'].items():
        shown = param["value"]
        if param['type'] == 'class':
            shown = shown.rpartition('.')[2]
        row = f'{pname:<22} = {str(shown):<12}'
        suffix = '' if param['specified'] else ' (default)'
        lines.append(row + suffix + '\n')
    return ''.join(lines)


def get_default_params(some_class):
    """Collect constructor parameters that declare a default value, across the MRO.

    Walks ``some_class.mro()`` from most- to least-derived, skipping
    ``nn.Module`` and ``object``, and records every ``__init__`` parameter
    that has a default. When several classes declare the same name, the
    most-derived declaration wins.

    Args:
        some_class: The class to inspect.

    Returns:
        dict mapping parameter name to its ``inspect.Parameter`` (read
        ``.default`` for the value), in discovery order.
    """
    params = {}
    for mclass in some_class.mro():
        if mclass is nn.Module or mclass is object:
            continue
        for pname, param in inspect.signature(mclass.__init__).parameters.items():
            # Identity check against the sentinel: ``!=`` would call the
            # default's own ``__ne__``, which returns an array (and then fails
            # in a boolean context) for array-like defaults.
            if param.default is not inspect.Parameter.empty and pname not in params:
                params[pname] = param
    return params


def get_classname(cls):
    """Return the dotted ``module.QualName`` path of ``cls``.

    The module prefix is omitted when ``cls.__module__`` is ``None`` or equals
    ``"__builtin__"`` (the Python 2 name of the builtins module). On Python 3
    builtins report ``"builtins"``, so that branch does not fire for them and
    e.g. ``int`` yields ``"builtins.int"``.
    """
    owner = cls.__module__
    qualified = cls.__qualname__
    if owner is None or owner == "__builtin__":
        return qualified
    return ".".join((owner, qualified))


def get_class_from_str(class_str):
    """Resolve a dotted path such as ``'pkg.module.ClassName'`` to the object it names.

    The longest importable module prefix of ``class_str`` is imported and the
    remaining components are looked up as attributes. This means nested names
    built from ``__qualname__`` (e.g. ``'pkg.module.Outer.Inner'``) resolve too.

    Args:
        class_str: Dotted path with at least one module component.

    Returns:
        The object found at the end of the path.

    Raises:
        ValueError: If ``class_str`` has no module component (no dot).
        ModuleNotFoundError: If no prefix of ``class_str`` is an importable
            module, or if importing a module fails because one of its own
            dependencies is missing.
        AttributeError: If a trailing component does not exist.
    """
    import importlib

    components = class_str.split('.')
    if len(components) < 2:
        raise ValueError(f'{class_str!r} has no module component')

    own_prefixes = {'.'.join(components[:i]) for i in range(1, len(components))}
    missing = None
    # Try 'a.b.c' before 'a.b' before 'a'; the last component is never a module.
    for split_at in range(len(components) - 1, 0, -1):
        module_name = '.'.join(components[:split_at])
        try:
            obj = importlib.import_module(module_name)
        except ModuleNotFoundError as err:
            # Fall back to a shorter prefix only when the missing module is
            # part of the path itself. A missing third-party dependency raised
            # while importing a real module is a genuine error.
            if err.name not in own_prefixes:
                raise
            missing = err
            continue
        for attr in components[split_at:]:
            obj = getattr(obj, attr)
        return obj
    raise missing