# Inspired by https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/stochastic_weight_avg.py
# https://github.com/PyTorchLightning/Lightning-Bolts/blob/master/pl_bolts/callbacks/byol_updates.py
# https://forums.pytorchlightning.ai/t/adopting-exponential-moving-average-ema-for-pl-pipeline/488/2
# https://github.com/PyTorchLightning/pytorch-lightning/issues/8100
from typing import Dict, Any

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.types import STEP_OUTPUT

from src.utils.ema import ExponentialMovingAverage

class EMACallback(Callback):
    """Maintain an exponential moving average of the module's trainable parameters,
    and run validation/testing with the averaged weights.

    TD [2021-08-31]: saving and loading from checkpoint should work.
    """
    def __init__(self, decay: float, use_num_updates: bool = True):
        """
        decay: The exponential decay.
        use_num_updates: Whether to use number of updates when computing
            averages.
        """
        super().__init__()
        self.decay = decay
        self.use_num_updates = use_num_updates
        self.ema = None
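        # Assumed: `ExponentialMovingAverage` follows the torch_ema-style API
        # (update / store / copy_to / restore / state_dict / load_state_dict),
        # where each update moves the shadow weights roughly as
        #     shadow = decay * shadow + (1 - decay) * param,
        # and `use_num_updates` lowers the effective decay early in training.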

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        # It's possible that we already loaded EMA from the checkpoint
        if self.ema is None:
            self.ema = ExponentialMovingAverage(
                [p for p in pl_module.parameters() if p.requires_grad],
                decay=self.decay, use_num_updates=self.use_num_updates,
            )

    # Ideally we want on_after_optimizer_step but pytorch-lightning doesn't have it
    # We only want to update when parameters are changing.
    # Because of gradient accumulation, this doesn't happen every training step.
    # https://github.com/PyTorchLightning/pytorch-lightning/issues/11688
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        if (batch_idx + 1) % trainer.accumulate_grad_batches == 0:
            self.ema.update()
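        # For example, with trainer.accumulate_grad_batches == 4 the optimizer
        # steps after batches 3, 7, 11, ..., so the EMA is updated once per
        # effective optimizer step rather than once per micro-batch.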

    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # During the initial validation we don't have self.ema yet
        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.ema is not None:
            self.ema.restore()
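
    # Validation and test (below) run with the averaged weights: store() snapshots
    # the live training weights, copy_to() swaps in the EMA weights, and restore()
    # puts the training weights back once evaluation is finished.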
    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.ema is not None:
            self.ema.restore()

    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> Dict[str, Any]:
        return self.ema.state_dict()
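
    # Assuming the legacy (pre-1.6) callback checkpoint hooks this file appears to
    # target: the dict returned above is stored per-callback in the checkpoint and
    # handed back as `checkpoint` below, which is why load_state_dict(checkpoint)
    # receives exactly the EMA state rather than the full Lightning checkpoint.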
    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any]
    ) -> None:
        if self.ema is None:
            self.ema = ExponentialMovingAverage(
                [p for p in pl_module.parameters() if p.requires_grad],
                decay=self.decay, use_num_updates=self.use_num_updates,
            )
        self.ema.load_state_dict(checkpoint)
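

# A minimal usage sketch: it wires the callback into a Trainer with a toy
# LightningModule so the EMA hooks fire during fit. The toy module, random
# data, and Trainer flags are illustrative only, not prescribed by this repo.
if __name__ == "__main__":
    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset

    class ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=1e-2)

    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    loader = DataLoader(dataset, batch_size=16)
    trainer = pl.Trainer(
        max_epochs=1,
        accumulate_grad_batches=2,  # EMA updates once every 2 batches
        callbacks=[EMACallback(decay=0.999)],
    )
    trainer.fit(ToyModule(), loader, loader)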