Spaces:
Runtime error
Runtime error
CorvaeOboro
commited on
Commit
β’
6b7f20a
1
Parent(s):
e91c77c
Upload persistence.py
Browse files- torch_utils/persistence.py +251 -0
torch_utils/persistence.py
ADDED
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ο»Ώ# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Facilities for pickling Python code alongside other data.
|
10 |
+
|
11 |
+
The pickled code is automatically imported into a separate Python module
|
12 |
+
during unpickling. This way, any previously exported pickles will remain
|
13 |
+
usable even if the original code is no longer available, or if the current
|
14 |
+
version of the code is not consistent with what was originally pickled."""
|
15 |
+
|
16 |
+
import sys
|
17 |
+
import pickle
|
18 |
+
import io
|
19 |
+
import inspect
|
20 |
+
import copy
|
21 |
+
import uuid
|
22 |
+
import types
|
23 |
+
import dnnlib
|
24 |
+
|
25 |
+
#----------------------------------------------------------------------------
|
26 |
+
|
27 |
+
_version = 6 # internal version number
|
28 |
+
_decorators = set() # {decorator_class, ...}
|
29 |
+
_import_hooks = [] # [hook_function, ...]
|
30 |
+
_module_to_src_dict = dict() # {module: src, ...}
|
31 |
+
_src_to_module_dict = dict() # {src: module, ...}
|
32 |
+
|
33 |
+
#----------------------------------------------------------------------------
|
34 |
+
|
35 |
+
def persistent_class(orig_class):
|
36 |
+
r"""Class decorator that extends a given class to save its source code
|
37 |
+
when pickled.
|
38 |
+
|
39 |
+
Example:
|
40 |
+
|
41 |
+
from torch_utils import persistence
|
42 |
+
|
43 |
+
@persistence.persistent_class
|
44 |
+
class MyNetwork(torch.nn.Module):
|
45 |
+
def __init__(self, num_inputs, num_outputs):
|
46 |
+
super().__init__()
|
47 |
+
self.fc = MyLayer(num_inputs, num_outputs)
|
48 |
+
...
|
49 |
+
|
50 |
+
@persistence.persistent_class
|
51 |
+
class MyLayer(torch.nn.Module):
|
52 |
+
...
|
53 |
+
|
54 |
+
When pickled, any instance of `MyNetwork` and `MyLayer` will save its
|
55 |
+
source code alongside other internal state (e.g., parameters, buffers,
|
56 |
+
and submodules). This way, any previously exported pickle will remain
|
57 |
+
usable even if the class definitions have been modified or are no
|
58 |
+
longer available.
|
59 |
+
|
60 |
+
The decorator saves the source code of the entire Python module
|
61 |
+
containing the decorated class. It does *not* save the source code of
|
62 |
+
any imported modules. Thus, the imported modules must be available
|
63 |
+
during unpickling, also including `torch_utils.persistence` itself.
|
64 |
+
|
65 |
+
It is ok to call functions defined in the same module from the
|
66 |
+
decorated class. However, if the decorated class depends on other
|
67 |
+
classes defined in the same module, they must be decorated as well.
|
68 |
+
This is illustrated in the above example in the case of `MyLayer`.
|
69 |
+
|
70 |
+
It is also possible to employ the decorator just-in-time before
|
71 |
+
calling the constructor. For example:
|
72 |
+
|
73 |
+
cls = MyLayer
|
74 |
+
if want_to_make_it_persistent:
|
75 |
+
cls = persistence.persistent_class(cls)
|
76 |
+
layer = cls(num_inputs, num_outputs)
|
77 |
+
|
78 |
+
As an additional feature, the decorator also keeps track of the
|
79 |
+
arguments that were used to construct each instance of the decorated
|
80 |
+
class. The arguments can be queried via `obj.init_args` and
|
81 |
+
`obj.init_kwargs`, and they are automatically pickled alongside other
|
82 |
+
object state. A typical use case is to first unpickle a previous
|
83 |
+
instance of a persistent class, and then upgrade it to use the latest
|
84 |
+
version of the source code:
|
85 |
+
|
86 |
+
with open('old_pickle.pkl', 'rb') as f:
|
87 |
+
old_net = pickle.load(f)
|
88 |
+
new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
|
89 |
+
misc.copy_params_and_buffers(old_net, new_net, require_all=True)
|
90 |
+
"""
|
91 |
+
assert isinstance(orig_class, type)
|
92 |
+
if is_persistent(orig_class):
|
93 |
+
return orig_class
|
94 |
+
|
95 |
+
assert orig_class.__module__ in sys.modules
|
96 |
+
orig_module = sys.modules[orig_class.__module__]
|
97 |
+
orig_module_src = _module_to_src(orig_module)
|
98 |
+
|
99 |
+
class Decorator(orig_class):
|
100 |
+
_orig_module_src = orig_module_src
|
101 |
+
_orig_class_name = orig_class.__name__
|
102 |
+
|
103 |
+
def __init__(self, *args, **kwargs):
|
104 |
+
super().__init__(*args, **kwargs)
|
105 |
+
self._init_args = copy.deepcopy(args)
|
106 |
+
self._init_kwargs = copy.deepcopy(kwargs)
|
107 |
+
assert orig_class.__name__ in orig_module.__dict__
|
108 |
+
_check_pickleable(self.__reduce__())
|
109 |
+
|
110 |
+
@property
|
111 |
+
def init_args(self):
|
112 |
+
return copy.deepcopy(self._init_args)
|
113 |
+
|
114 |
+
@property
|
115 |
+
def init_kwargs(self):
|
116 |
+
return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
|
117 |
+
|
118 |
+
def __reduce__(self):
|
119 |
+
fields = list(super().__reduce__())
|
120 |
+
fields += [None] * max(3 - len(fields), 0)
|
121 |
+
if fields[0] is not _reconstruct_persistent_obj:
|
122 |
+
meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
|
123 |
+
fields[0] = _reconstruct_persistent_obj # reconstruct func
|
124 |
+
fields[1] = (meta,) # reconstruct args
|
125 |
+
fields[2] = None # state dict
|
126 |
+
return tuple(fields)
|
127 |
+
|
128 |
+
Decorator.__name__ = orig_class.__name__
|
129 |
+
_decorators.add(Decorator)
|
130 |
+
return Decorator
|
131 |
+
|
132 |
+
#----------------------------------------------------------------------------
|
133 |
+
|
134 |
+
def is_persistent(obj):
|
135 |
+
r"""Test whether the given object or class is persistent, i.e.,
|
136 |
+
whether it will save its source code when pickled.
|
137 |
+
"""
|
138 |
+
try:
|
139 |
+
if obj in _decorators:
|
140 |
+
return True
|
141 |
+
except TypeError:
|
142 |
+
pass
|
143 |
+
return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
|
144 |
+
|
145 |
+
#----------------------------------------------------------------------------
|
146 |
+
|
147 |
+
def import_hook(hook):
|
148 |
+
r"""Register an import hook that is called whenever a persistent object
|
149 |
+
is being unpickled. A typical use case is to patch the pickled source
|
150 |
+
code to avoid errors and inconsistencies when the API of some imported
|
151 |
+
module has changed.
|
152 |
+
|
153 |
+
The hook should have the following signature:
|
154 |
+
|
155 |
+
hook(meta) -> modified meta
|
156 |
+
|
157 |
+
`meta` is an instance of `dnnlib.EasyDict` with the following fields:
|
158 |
+
|
159 |
+
type: Type of the persistent object, e.g. `'class'`.
|
160 |
+
version: Internal version number of `torch_utils.persistence`.
|
161 |
+
module_src Original source code of the Python module.
|
162 |
+
class_name: Class name in the original Python module.
|
163 |
+
state: Internal state of the object.
|
164 |
+
|
165 |
+
Example:
|
166 |
+
|
167 |
+
@persistence.import_hook
|
168 |
+
def wreck_my_network(meta):
|
169 |
+
if meta.class_name == 'MyNetwork':
|
170 |
+
print('MyNetwork is being imported. I will wreck it!')
|
171 |
+
meta.module_src = meta.module_src.replace("True", "False")
|
172 |
+
return meta
|
173 |
+
"""
|
174 |
+
assert callable(hook)
|
175 |
+
_import_hooks.append(hook)
|
176 |
+
|
177 |
+
#----------------------------------------------------------------------------
|
178 |
+
|
179 |
+
def _reconstruct_persistent_obj(meta):
|
180 |
+
r"""Hook that is called internally by the `pickle` module to unpickle
|
181 |
+
a persistent object.
|
182 |
+
"""
|
183 |
+
meta = dnnlib.EasyDict(meta)
|
184 |
+
meta.state = dnnlib.EasyDict(meta.state)
|
185 |
+
for hook in _import_hooks:
|
186 |
+
meta = hook(meta)
|
187 |
+
assert meta is not None
|
188 |
+
|
189 |
+
assert meta.version == _version
|
190 |
+
module = _src_to_module(meta.module_src)
|
191 |
+
|
192 |
+
assert meta.type == 'class'
|
193 |
+
orig_class = module.__dict__[meta.class_name]
|
194 |
+
decorator_class = persistent_class(orig_class)
|
195 |
+
obj = decorator_class.__new__(decorator_class)
|
196 |
+
|
197 |
+
setstate = getattr(obj, '__setstate__', None)
|
198 |
+
if callable(setstate):
|
199 |
+
setstate(meta.state) # pylint: disable=not-callable
|
200 |
+
else:
|
201 |
+
obj.__dict__.update(meta.state)
|
202 |
+
return obj
|
203 |
+
|
204 |
+
#----------------------------------------------------------------------------
|
205 |
+
|
206 |
+
def _module_to_src(module):
|
207 |
+
r"""Query the source code of a given Python module.
|
208 |
+
"""
|
209 |
+
src = _module_to_src_dict.get(module, None)
|
210 |
+
if src is None:
|
211 |
+
src = inspect.getsource(module)
|
212 |
+
_module_to_src_dict[module] = src
|
213 |
+
_src_to_module_dict[src] = module
|
214 |
+
return src
|
215 |
+
|
216 |
+
def _src_to_module(src):
|
217 |
+
r"""Get or create a Python module for the given source code.
|
218 |
+
"""
|
219 |
+
module = _src_to_module_dict.get(src, None)
|
220 |
+
if module is None:
|
221 |
+
module_name = "_imported_module_" + uuid.uuid4().hex
|
222 |
+
module = types.ModuleType(module_name)
|
223 |
+
sys.modules[module_name] = module
|
224 |
+
_module_to_src_dict[module] = src
|
225 |
+
_src_to_module_dict[src] = module
|
226 |
+
exec(src, module.__dict__) # pylint: disable=exec-used
|
227 |
+
return module
|
228 |
+
|
229 |
+
#----------------------------------------------------------------------------
|
230 |
+
|
231 |
+
def _check_pickleable(obj):
|
232 |
+
r"""Check that the given object is pickleable, raising an exception if
|
233 |
+
it is not. This function is expected to be considerably more efficient
|
234 |
+
than actually pickling the object.
|
235 |
+
"""
|
236 |
+
def recurse(obj):
|
237 |
+
if isinstance(obj, (list, tuple, set)):
|
238 |
+
return [recurse(x) for x in obj]
|
239 |
+
if isinstance(obj, dict):
|
240 |
+
return [[recurse(x), recurse(y)] for x, y in obj.items()]
|
241 |
+
if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
|
242 |
+
return None # Python primitive types are pickleable.
|
243 |
+
if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
|
244 |
+
return None # NumPy arrays and PyTorch tensors are pickleable.
|
245 |
+
if is_persistent(obj):
|
246 |
+
return None # Persistent objects are pickleable, by virtue of the constructor check.
|
247 |
+
return obj
|
248 |
+
with io.BytesIO() as f:
|
249 |
+
pickle.dump(recurse(obj), f)
|
250 |
+
|
251 |
+
#----------------------------------------------------------------------------
|