|
|
|
import os.path as osp |
|
import time |
|
from tempfile import TemporaryDirectory |
|
|
|
import torch |
|
from torch.optim import Optimizer |
|
|
|
import mmcv |
|
from mmcv.parallel import is_module_wrapper |
|
from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict |
|
|
|
# apex is an optional dependency. When it cannot be imported, bind the name
# to ``None`` so later code can test ``apex is not None`` instead of
# hitting a NameError.
try:
    import apex
except ImportError:
    apex = None
    print('apex is not installed')
|
|
|
|
|
def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint is a dict with the fields ``meta`` and ``state_dict``,
    plus ``optimizer`` when ``optimizer`` is given. ``meta`` always gets
    the mmcv version and a timestamp added, and also the model's
    ``CLASSES`` when the model defines a non-None ``CLASSES`` attribute.

    NOTE(review): no ``amp`` (apex mixed-precision) state is written by
    this function. If AMP state must be checkpointed, add it explicitly.

    Args:
        model (Module): Module whose params are to be saved. If
            ``is_module_wrapper(model)`` is true, ``model.module`` is saved
            instead.
        filename (str): Checkpoint filename. A ``pavi://`` prefix uploads
            the checkpoint to PAVI modelcloud instead of writing it to the
            local filesystem.
        optimizer (:obj:`Optimizer` | dict, optional): Optimizer, or a dict
            mapping names to optimizers, whose state is to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint. The
            caller's dict is copied and is not modified.

    Raises:
        TypeError: If ``meta`` is neither a dict nor None.
        ImportError: If ``filename`` starts with ``pavi://`` and pavi is
            not installed.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    else:
        # Shallow copy so the updates below do not leak into the caller's dict.
        meta = dict(meta)
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    # Save the wrapped module's weights, not the wrapper's.
    if is_module_wrapper(model):
        model = model.module

    if getattr(model, 'CLASSES', None) is not None:
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model)),
    }

    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict() for name, optim in optimizer.items()
        }

    if filename.startswith('pavi://'):
        _save_checkpoint_to_pavi(checkpoint, filename)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # Leaving the ``with`` block closes, and therefore flushes, the file.
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)


def _save_checkpoint_to_pavi(checkpoint, filename):
    """Upload ``checkpoint`` to PAVI modelcloud at ``pavi://<dir>/<name>``.

    Raises:
        ImportError: If pavi is not installed.
    """
    try:
        from pavi import modelcloud
        from pavi.exception import NodeNotFoundError
    except ImportError as err:
        raise ImportError(
            'Please install pavi to save checkpoint to modelcloud.') from err
    model_path = filename[len('pavi://'):]
    root = modelcloud.Folder()
    model_dir, model_name = osp.split(model_path)
    try:
        pavi_model = modelcloud.get(model_dir)
    except NodeNotFoundError:
        # Target node is missing; create it under the modelcloud root.
        pavi_model = root.create_training_model(model_dir)
    # Serialize to a local temporary file first, then upload that file.
    with TemporaryDirectory() as tmp_dir:
        checkpoint_file = osp.join(tmp_dir, model_name)
        with open(checkpoint_file, 'wb') as f:
            torch.save(checkpoint, f)
        pavi_model.create_file(checkpoint_file, name=model_name)
|
|