import inspect
import shutil
import tempfile
import typing
from pathlib import Path
import torch
from torch import nn
class BaseModel(nn.Module):
"""This is a class that adds useful save/load functionality to a
``torch.nn.Module`` object. ``BaseModel`` objects can be saved
as ``torch.package`` easily, making them super easy to port between
machines without requiring a ton of dependencies. Files can also be
saved as just weights, in the standard way.
    >>> class Model(ml.BaseModel):
    >>>     def __init__(self, arg1: float = 1.0):
    >>>         super().__init__()
    >>>         self.arg1 = arg1
    >>>         self.linear = nn.Linear(1, 1)
    >>>
    >>>     def forward(self, x):
    >>>         return self.linear(x)
    >>>
    >>> model1 = Model()
    >>> x = torch.randn(1, 1)
    >>> out1 = model1(x)
    >>>
    >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
    >>>     model1.save(f.name)
    >>>     model2 = Model.load(f.name)
    >>>     out2 = model2(x)
    >>>     assert torch.allclose(out1, out2)
    >>>
    >>>     model1.save(f.name, package=True)
    >>>     model2 = Model.load(f.name)
    >>>     model2.save(f.name, package=False)
    >>>     model3 = Model.load(f.name)
    >>>     out3 = model3(x)
    >>>     assert torch.allclose(out1, out3)
    >>>
    >>> with tempfile.TemporaryDirectory() as d:
    >>>     model1.save_to_folder(d, {"data": 1.0})
    >>>     Model.load_from_folder(d)
"""
EXTERN = [
"audiotools.**",
"tqdm",
"__main__",
"numpy.**",
"julius.**",
"torchaudio.**",
"scipy.**",
"einops",
]
"""Names of libraries that are external to the torch.package saving mechanism.
Source code from these libraries will not be packaged into the model. This can
be edited by the user of this class by editing ``model.EXTERN``."""
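    # A sketch of customizing this: a hypothetical subclass ``MyModel`` could
    # keep an extra dependency's source out of saved packages, either by
    # extending the class attribute or per call via ``save``'s ``extern``
    # argument (``librosa.**`` is only an illustrative pattern):
    #
    #   MyModel.EXTERN = MyModel.EXTERN + ["librosa.**"]
    #   model.save("model.pth", extern=["librosa.**"])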
INTERN = []
"""Names of libraries that are internal to the torch.package saving mechanism.
Source code from these libraries will be saved alongside the model."""
def save(
self,
path: str,
metadata: dict = None,
package: bool = True,
intern: list = [],
extern: list = [],
mock: list = [],
):
"""Saves the model, either as a torch package, or just as
weights, alongside some specified metadata.
Parameters
----------
path : str
Path to save model to.
metadata : dict, optional
Any metadata to save alongside the model,
by default None
package : bool, optional
Whether to use ``torch.package`` to save the model in
a format that is portable, by default True
intern : list, optional
List of additional libraries that are internal
to the model, used with torch.package, by default []
extern : list, optional
List of additional libraries that are external to
the model, used with torch.package, by default []
mock : list, optional
List of libraries to mock, used with torch.package,
by default []
Returns
-------
str
Path to saved model.
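
        Examples
        --------
        A sketch of typical usage (``model`` is an instance of any
        ``BaseModel`` subclass; paths and metadata are illustrative):

        >>> model.save("model_package.pth")  # portable torch.package archive
        >>> model.save("model_weights.pth", package=False, metadata={"sr": 44100})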
"""
sig = inspect.signature(self.__class__)
args = {}
for key, val in sig.parameters.items():
arg_val = val.default
if arg_val is not inspect.Parameter.empty:
args[key] = arg_val
        # Look up attributes in self, and if any of them are in args,
        # overwrite them in args.
for attribute in dir(self):
if attribute in args:
args[attribute] = getattr(self, attribute)
metadata = {} if metadata is None else metadata
metadata["kwargs"] = args
if not hasattr(self, "metadata"):
self.metadata = {}
self.metadata.update(metadata)
if not package:
state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
torch.save(state_dict, path)
else:
self._save_package(path, intern=intern, extern=extern, mock=mock)
return path
@property
def device(self):
"""Gets the device the model is on by looking at the device of
the first parameter. May not be valid if model is split across
multiple devices.
"""
        return next(self.parameters()).device
@classmethod
def load(
cls,
location: str,
*args,
package_name: str = None,
strict: bool = False,
**kwargs,
):
"""Load model from a path. Tries first to load as a package, and if
that fails, tries to load as weights. The arguments to the class are
specified inside the model weights file.
Parameters
----------
location : str
Path to file.
package_name : str, optional
Name of package, by default ``cls.__name__``.
        strict : bool, optional
            Whether to raise an error on keys in the state dict that do not
            match the model's keys, by default False (unmatched keys are
            ignored)
kwargs : dict
Additional keyword arguments to the model instantiation, if
not loading from package.
Returns
-------
BaseModel
A model that inherits from BaseModel.
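
        Examples
        --------
        A sketch of typical usage, reusing the ``Model`` class from the class
        docstring (paths and the ``arg1`` override are illustrative):

        >>> model = Model.load("model_package.pth")
        >>> model = Model.load("model_weights.pth", arg1=2.0)  # weights only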
"""
try:
model = cls._load_package(location, package_name=package_name)
        except Exception:
            model_dict = torch.load(location, map_location="cpu")
metadata = model_dict["metadata"]
metadata["kwargs"].update(kwargs)
sig = inspect.signature(cls)
class_keys = list(sig.parameters.keys())
for k in list(metadata["kwargs"].keys()):
if k not in class_keys:
metadata["kwargs"].pop(k)
model = cls(*args, **metadata["kwargs"])
model.load_state_dict(model_dict["state_dict"], strict=strict)
model.metadata = metadata
return model
def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
package_name = type(self).__name__
resource_name = f"{type(self).__name__}.pth"
# Below is for loading and re-saving a package.
if hasattr(self, "importer"):
kwargs["importer"] = (self.importer, torch.package.sys_importer)
del self.importer
# Why do we use a tempfile, you ask?
# It's so we can load a packaged model and then re-save
# it to the same location. torch.package throws an
# error if it's loading and writing to the same
# file (this is undocumented).
with tempfile.NamedTemporaryFile(suffix=".pth") as f:
with torch.package.PackageExporter(f.name, **kwargs) as exp:
exp.intern(self.INTERN + intern)
exp.mock(mock)
exp.extern(self.EXTERN + extern)
exp.save_pickle(package_name, resource_name, self)
if hasattr(self, "metadata"):
exp.save_pickle(
package_name, f"{package_name}.metadata", self.metadata
)
shutil.copyfile(f.name, path)
# Must reset the importer back to `self` if it existed
# so that you can save the model again!
if "importer" in kwargs:
self.importer = kwargs["importer"][0]
return path
@classmethod
def _load_package(cls, path, package_name=None):
package_name = cls.__name__ if package_name is None else package_name
resource_name = f"{package_name}.pth"
imp = torch.package.PackageImporter(path)
model = imp.load_pickle(package_name, resource_name, "cpu")
try:
model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata")
        except Exception:  # pragma: no cover
pass
model.importer = imp
return model
def save_to_folder(
self,
folder: typing.Union[str, Path],
extra_data: dict = None,
package: bool = True,
):
"""Dumps a model into a folder, as both a package
and as weights, as well as anything specified in
``extra_data``. ``extra_data`` is a dictionary of other
pickleable files, with the keys being the paths
to save them in. The model is saved under a subfolder
specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
if the model name was ``Generator``).
        >>> with tempfile.TemporaryDirectory() as d:
        >>>     extra_data = {
        >>>         "optimizer.pth": optimizer.state_dict()
        >>>     }
        >>>     model.save_to_folder(d, extra_data)
        >>>     Model.load_from_folder(d)

        Parameters
        ----------
        folder : typing.Union[str, Path]
            Folder to save the model into.
        extra_data : dict, optional
            Any additional pickleable data to save alongside the model,
            with keys giving the filenames to save them under,
            by default None
        package : bool, optional
            Whether to also save the model as a ``torch.package``
            archive (``package.pth``), by default True
Returns
-------
        Path
            Path to the folder the model was saved in.
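
        Examples
        --------
        The resulting layout for a class named ``Generator``, assuming the
        ``extra_data`` shown above::

            folder/
                generator/
                    package.pth
                    weights.pth
                    optimizer.pth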
"""
extra_data = {} if extra_data is None else extra_data
model_name = type(self).__name__.lower()
target_base = Path(f"{folder}/{model_name}/")
target_base.mkdir(exist_ok=True, parents=True)
if package:
            package_path = target_base / "package.pth"
self.save(package_path)
        weights_path = target_base / "weights.pth"
self.save(weights_path, package=False)
for path, obj in extra_data.items():
torch.save(obj, target_base / path)
return target_base
@classmethod
def load_from_folder(
cls,
folder: typing.Union[str, Path],
package: bool = True,
strict: bool = False,
**kwargs,
):
"""Loads the model from a folder generated by
:py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
Like that function, this one looks for a subfolder that has
the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
model name was ``Generator``).
Parameters
----------
        folder : typing.Union[str, Path]
            Folder the model was saved into via ``save_to_folder``.
        package : bool, optional
            Whether to use ``torch.package`` to load the model,
            loading the model from ``package.pth``, by default True
        strict : bool, optional
            Whether to raise an error on keys in the state dict that do
            not match the model's keys, by default False
        kwargs : dict
            Additional keyword arguments passed to ``torch.load`` when
            loading the extra data saved alongside the model.
Returns
-------
tuple
tuple of model and extra data as saved by
:py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
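
        Examples
        --------
        A sketch of typical usage, mirroring the example in
        ``save_to_folder`` (``optimizer`` and the path are illustrative):

        >>> model, extra_data = Model.load_from_folder("checkpoints/")
        >>> optimizer.load_state_dict(extra_data["optimizer.pth"])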
"""
folder = Path(folder) / cls.__name__.lower()
model_pth = "package.pth" if package else "weights.pth"
model_pth = folder / model_pth
model = cls.load(model_pth, strict=strict)
extra_data = {}
excluded = ["package.pth", "weights.pth"]
files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded]
for f in files:
extra_data[f.name] = torch.load(f, **kwargs)
return model, extra_data