# Meant to work with Pytorch's ZeroRedundancyOptimizer
from typing import Any, Dict, Optional
from pathlib import Path

import torch
from torch.optim.optimizer import Optimizer
from torch.distributed.optim import ZeroRedundancyOptimizer

from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.core.optimizer import LightningOptimizer

try:  # pytorch_lightning <= 1.7
    from pytorch_lightning.utilities.types import _PATH
except ImportError:  # pytorch_lightning >= 1.8
    try:
        from lightning_lite.utilities.types import _PATH
    except ImportError:  # pytorch_lightning >= 1.9
        from lightning_fabric.utilities.types import _PATH
# Copied from Pytorch's ZeroRedundancyOptimizer's state_dict method, but we only get
# the local state dict to avoid synchronization across GPUs.
# https://github.com/pytorch/pytorch/blob/0c7ca2d97ba5980a2af7dcd6b8106dc915e591cd/torch/distributed/optim/zero_redundancy_optimizer.py#L1131
def get_zero_optimizer_state_dict_local(optimizer, global_rank):
    optimizer._check_overlap_initialized()
    # Sync the exposed `param_groups` attributes to the local optimizer in
    # case they have been updated
    optimizer._sync_param_groups(optimizer.param_groups, optimizer.optim.param_groups)
    local_state_dict = optimizer.optim.state_dict()
    state_dict = super(ZeroRedundancyOptimizer, optimizer).state_dict()
    # Update the global optimizer state with local state information,
    # factoring in the translation from local to global indexing
    rank = global_rank
    # TODO: recursive copy to device
    local_param_groups = local_state_dict["param_groups"]
    global_param_groups = optimizer._partition_parameters()[rank]
    assert len(local_param_groups) == len(global_param_groups), \
        "Mismatch between number of local and global parameter groups"
    for local_param_group, global_param_group in zip(local_param_groups, global_param_groups):
        # `local_param_group` stores local indices, while
        # `global_param_group` stores the tensors directly
        local_param_indices = local_param_group["params"]
        global_params = global_param_group["params"]
        assert len(local_param_indices) == len(global_params), \
            "Mismatch between number of local and global parameters in parameter group"
        for local_param_index, global_param in zip(local_param_indices, global_params):
            # Update the global parameter state, if any
            if local_param_index in local_state_dict["state"]:
                global_param_index = optimizer._param_to_index[global_param]
                state_dict["state"][global_param_index] = local_state_dict["state"][local_param_index]
    # Sort the parameters in the state
    state_dict["state"] = dict(sorted(state_dict["state"].items()))
    return state_dict
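
# Illustration (not part of the strategy): unlike ZeroRedundancyOptimizer.state_dict(),
# which requires a prior consolidate_state_dict() call (a cross-rank gather), the helper
# above returns only the states this rank owns, keyed by global parameter index.
# A minimal sketch, assuming torch.distributed is already initialized and `model` is an
# nn.Module replicated on each rank:
#
#   zero_optim = ZeroRedundancyOptimizer(model.parameters(),
#                                        optimizer_class=torch.optim.AdamW, lr=1e-3)
#   ...  # training steps
#   rank = torch.distributed.get_rank()
#   shard = get_zero_optimizer_state_dict_local(zero_optim, rank)
#   torch.save(shard, f'{rank:03d}_optim_states.pt')  # one shard file per rank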

class DDPStrategyZero1(DDPStrategy):
    """To use ZeroRedundancyOptimizer, we need to shard the optimizer states when
    saving/loading checkpoints.
    """

    strategy_name = "ddp_zero1"

    def optimizer_state(self, optimizer: Optimizer) -> Optional[dict]:
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer
        if isinstance(optimizer, ZeroRedundancyOptimizer):
            # Only this rank's shard of the optimizer state, no cross-rank sync
            return get_zero_optimizer_state_dict_local(optimizer, self.global_rank)
        else:
            return optimizer.state_dict()
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ) -> None:
        """Save model/training states as a checkpoint directory: rank 0 writes the
        model/trainer state, and every rank writes its own optimizer-state shard.

        Args:
            checkpoint: dict containing model and trainer state
            filepath: write-target checkpoint directory's path
            storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin
        """
        filepath = Path(filepath)
        filepath.mkdir(parents=True, exist_ok=True)
        # Optimizer states are sharded per rank, so each rank saves only its own shard,
        # while the (replicated) model/trainer state is saved by rank 0 alone.
        local_optimizer_states = checkpoint.pop('optimizer_states')
        if self.is_global_zero:
            self.checkpoint_io.save_checkpoint(checkpoint, filepath / 'model_states.pt',
                                               storage_options=storage_options)
        self.checkpoint_io.save_checkpoint(local_optimizer_states,
                                           filepath / f'{self.global_rank:03d}_optim_states.pt',
                                           storage_options=storage_options)
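
    # Resulting on-disk layout (hypothetical example: checkpoint path `last.ckpt`,
    # 2 ranks). `filepath` becomes a directory rather than a single file:
    #
    #   last.ckpt/
    #       model_states.pt        # model + trainer state, written by rank 0 only
    #       000_optim_states.pt    # optimizer shard owned by rank 0
    #       001_optim_states.pt    # optimizer shard owned by rank 1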
    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
        torch.cuda.empty_cache()
        checkpoint_path = Path(checkpoint_path)
        # Single-file checkpoints (e.g. saved by plain DDP) go through the default path
        if checkpoint_path.is_file():
            return super().load_checkpoint(str(checkpoint_path))
        else:
            assert checkpoint_path.is_dir()
            # Sharded checkpoint: load the shared model/trainer state, then attach
            # this rank's own optimizer shard
            global_states = self.checkpoint_io.load_checkpoint(checkpoint_path / 'model_states.pt')
            local_optimizer_states = self.checkpoint_io.load_checkpoint(
                checkpoint_path / f'{self.global_rank:03d}_optim_states.pt')
            global_states['optimizer_states'] = local_optimizer_states
            return global_states
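
# Usage sketch (assumptions: `MyModel`, the optimizer hyperparameters, and the device
# count below are placeholders, not part of this module). Return a
# ZeroRedundancyOptimizer from configure_optimizers and pass this strategy to the
# Trainer; checkpoints are then written as directories containing `model_states.pt`
# plus one `NNN_optim_states.pt` shard per rank.
#
#   import pytorch_lightning as pl
#
#   class MyModel(pl.LightningModule):
#       ...
#       def configure_optimizers(self):
#           # Each rank keeps only its shard of the optimizer state (ZeRO stage 1)
#           return ZeroRedundancyOptimizer(self.parameters(),
#                                          optimizer_class=torch.optim.AdamW, lr=3e-4)
#
#   trainer = pl.Trainer(accelerator="gpu", devices=8, strategy=DDPStrategyZero1())
#   trainer.fit(MyModel())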