# Inspired by https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/stochastic_weight_avg.py
# https://github.com/PyTorchLightning/Lightning-Bolts/blob/master/pl_bolts/callbacks/byol_updates.py
# https://forums.pytorchlightning.ai/t/adopting-exponential-moving-average-ema-for-pl-pipeline/488/2
# https://github.com/PyTorchLightning/pytorch-lightning/issues/8100
from typing import Dict, Any
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.utilities.types import STEP_OUTPUT
from src.utils.ema import ExponentialMovingAverage


class EMACallback(Callback):
    """Maintain an exponential moving average (EMA) of the trainable model parameters.

    TD [2021-08-31]: saving and loading from checkpoint should work.
    """

    def __init__(self, decay: float, use_num_updates: bool = True):
        """
        decay: The exponential decay factor.
        use_num_updates: Whether to use the number of updates when computing
            averages.
        """
        super().__init__()
        self.decay = decay
        self.use_num_updates = use_num_updates
        self.ema = None

    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        # It's possible that we already loaded EMA from the checkpoint
        if self.ema is None:
            self.ema = ExponentialMovingAverage(
                [p for p in pl_module.parameters() if p.requires_grad],
                decay=self.decay, use_num_updates=self.use_num_updates,
            )

    # Ideally we want on_after_optimizer_step, but pytorch-lightning doesn't have it.
    # We only want to update when parameters are changing.
    # Because of gradient accumulation, this doesn't happen every training step.
    # https://github.com/PyTorchLightning/pytorch-lightning/issues/11688
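    # For example, with trainer.accumulate_grad_batches = 4 (an illustrative value, not a
    # setting from this repo), the optimizer only steps after batch indices 3, 7, 11, ...,
    # so the EMA is updated only on those batches.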
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        if (batch_idx + 1) % trainer.accumulate_grad_batches == 0:
            self.ema.update()
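
    # For validation and testing we evaluate with the averaged weights: store() saves the
    # current training weights, copy_to() swaps the EMA weights into the model, and
    # restore() puts the training weights back afterwards.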
    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # During the initial validation (sanity check) self.ema hasn't been created yet
        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.ema is not None:
            self.ema.restore()

    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if self.ema is not None:
            self.ema.restore()
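
    # Persist the EMA state with the Lightning checkpoint so that resumed runs (via
    # on_load_checkpoint below) continue the existing average instead of starting a new one.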
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> Dict[str, Any]:
        return self.ema.state_dict()

    def on_load_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any]
    ) -> None:
        if self.ema is None:
            self.ema = ExponentialMovingAverage(
                [p for p in pl_module.parameters() if p.requires_grad],
                decay=self.decay, use_num_updates=self.use_num_updates,
            )
        self.ema.load_state_dict(checkpoint)
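

# Minimal usage sketch (assumptions: ``MyLightningModule`` and ``train_dataloader`` are
# placeholders defined elsewhere in the project; the decay value is only illustrative):
#
#     import pytorch_lightning as pl
#
#     model = MyLightningModule()
#     trainer = pl.Trainer(max_epochs=10, callbacks=[EMACallback(decay=0.9999)])
#     trainer.fit(model, train_dataloader)
#
# During validation and test the callback evaluates with the EMA weights and restores the
# raw training weights afterwards; the EMA state travels with the Lightning checkpoint.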