# Copied from https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py
from __future__ import division
from __future__ import unicode_literals

from typing import Iterable, List, Optional
import weakref
import copy
import contextlib

import torch


def to_float_maybe(x):
    """Return `x` upcast to float32 if it is float16/bfloat16, else `x`."""
    return x.float() if x.dtype in [torch.float16, torch.bfloat16] else x


# Partially based on:
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
    """
    Maintains (exponential) moving average of a set of parameters.

    Only parameters that have `requires_grad=True` at construction time are
    tracked; `shadow_params` holds one averaged tensor per tracked parameter,
    in the order the parameters were given. Half-precision parameters are
    averaged in float32.

    Args:
        parameters: Iterable of `torch.nn.Parameter` (typically from
            `model.parameters()`).
        decay: The exponential decay.
        use_num_updates: Whether to use number of updates when computing
            averages.

    Raises:
        ValueError: If `decay` is outside `[0, 1]`.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float,
        use_num_updates: bool = True
    ):
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        parameters = list(parameters)
        # Remember which positions are tracked so that every later call can
        # line parameters up with `shadow_params`, even when the caller hands
        # us the full (frozen + trainable) parameter list again.
        self._tracked_mask = [bool(p.requires_grad) for p in parameters]
        self.shadow_params = [
            to_float_maybe(p.clone().detach())
            for p in parameters
            if p.requires_grad
        ]
        self.collected_params = None
        # By maintaining only a weakref to each parameter,
        # we maintain the old GC behaviour of ExponentialMovingAverage:
        # if the model goes out of scope but the ExponentialMovingAverage
        # is kept, no references to the model or its parameters will be
        # maintained, and the model will be cleaned up.
        self._params_refs = [weakref.ref(p) for p in parameters]

    def _get_parameters(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]]
    ) -> List[torch.nn.Parameter]:
        """
        Resolve `parameters` to the tracked parameters, aligned with
        `shadow_params`.

        Args:
            parameters: `None` to use the parameters given at construction,
                or an iterable that is either the full parameter list given
                at construction (same length and order) or only its tracked
                subset.

        Returns:
            A list with exactly one parameter per entry of `shadow_params`.

        Raises:
            ValueError: If a construction-time parameter was garbage
                collected (when `parameters` is `None`), or if the number of
                parameters passed matches neither the full nor the tracked
                parameter count.
        """
        if parameters is None:
            parameters = [p() for p in self._params_refs]
            if any(p is None for p in parameters):
                raise ValueError(
                    "(One of) the parameters with which this "
                    "ExponentialMovingAverage "
                    "was initialized no longer exists (was garbage collected);"
                    " please either provide `parameters` explicitly or keep "
                    "the model to which they belong from being garbage "
                    "collected."
                )
            return [
                p for p, tracked in zip(parameters, self._tracked_mask)
                if tracked
            ]

        parameters = list(parameters)
        if len(parameters) == len(self.shadow_params):
            # Already the tracked subset (or nothing is frozen).
            return parameters
        if (
            len(parameters) == len(self._tracked_mask)
            and sum(self._tracked_mask) == len(self.shadow_params)
        ):
            # Full parameter list: drop the positions that are not tracked.
            return [
                p for p, tracked in zip(parameters, self._tracked_mask)
                if tracked
            ]
        raise ValueError(
            "Number of parameters passed as argument is different "
            "from number of shadow parameters maintained by this "
            "ExponentialMovingAverage"
        )

    def update(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same set
                of parameters used to initialize this object. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm-up: use a smaller effective decay for the first updates.
            decay = min(
                decay,
                (1 + self.num_updates) / (10 + self.num_updates)
            )
        one_minus_decay = 1.0 - decay
        if not parameters:
            # Nothing is tracked; there is nothing to average.
            return
        # Follow the parameters if the model was moved to another device.
        if parameters[0].device != self.shadow_params[0].device:
            self.to(device=parameters[0].device)
        with torch.no_grad():
            for s_param, param in zip(self.shadow_params, parameters):
                # s <- s + (1 - decay) * (p - s), computed in the shadow dtype.
                torch.lerp(
                    s_param,
                    param.to(dtype=s_param.dtype),
                    one_minus_decay,
                    out=s_param
                )

    def copy_to(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def store(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        # Detach so the copies are plain leaf tensors; non-leaf tensors cannot
        # be deep-copied, which `load_state_dict` does to its input.
        self.collected_params = [
            param.detach().clone() for param in parameters
        ]

    def restore(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.

        Raises:
            RuntimeError: If `store` has not been called (nothing to restore).
        """
        if self.collected_params is None:
            raise RuntimeError(
                "This ExponentialMovingAverage has no `store()`ed weights "
                "to `restore()`"
            )
        parameters = self._get_parameters(parameters)
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    @contextlib.contextmanager
    def average_parameters(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ):
        r"""
        Context manager for validation/inference with averaged parameters.

        Equivalent to:

            ema.store()
            ema.copy_to()
            try:
                ...
            finally:
                ema.restore()

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        self.store(parameters)
        self.copy_to(parameters)
        try:
            yield
        finally:
            self.restore(parameters)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`
        and/or cast them to `dtype`.

        `dtype` is applied to floating-point buffers only; other buffers are
        only moved.

        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype)
            if p.is_floating_point()
            else p.to(device=device)
            for p in self.shadow_params
        ]
        if self.collected_params is not None:
            self.collected_params = [
                p.to(device=device, dtype=dtype)
                if p.is_floating_point()
                else p.to(device=device)
                for p in self.collected_params
            ]

    def state_dict(self) -> dict:
        r"""Returns the state of the ExponentialMovingAverage as a dict."""
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "num_updates": self.num_updates,
            "shadow_params": self.shadow_params,
            "collected_params": self.collected_params
        }

    def load_state_dict(self, state_dict: dict) -> None:
        r"""Loads the ExponentialMovingAverage state.

        The state is validated in full before anything is assigned, so a
        rejected `state_dict` leaves this object unchanged.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            ValueError: If the stored decay is outside `[0, 1]`, or the
                number of shadow parameters differs from the number of
                parameters tracked by this object.
            AssertionError: If an entry of `state_dict` has an unexpected
                type or `collected_params` has a different length than
                `shadow_params`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        decay = state_dict["decay"]
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        num_updates = state_dict["num_updates"]
        assert num_updates is None or isinstance(num_updates, int), \
            "Invalid num_updates"

        shadow_params = state_dict["shadow_params"]
        assert isinstance(shadow_params, list), \
            "shadow_params must be a list"
        assert all(
            isinstance(p, torch.Tensor) for p in shadow_params
        ), "shadow_params must all be Tensors"

        collected_params = state_dict["collected_params"]
        if collected_params is not None:
            assert isinstance(collected_params, list), \
                "collected_params must be a list"
            assert all(
                isinstance(p, torch.Tensor) for p in collected_params
            ), "collected_params must all be Tensors"
            assert len(collected_params) == len(shadow_params), \
                "collected_params and shadow_params had different lengths"

        # Shadow parameters exist only for tracked parameters, so compare
        # against those rather than against every construction-time parameter.
        tracked_refs = [
            ref for ref, tracked in zip(self._params_refs, self._tracked_mask)
            if tracked
        ]
        if len(shadow_params) != len(tracked_refs):
            raise ValueError(
                "Tried to `load_state_dict()` with the wrong number of "
                "parameters in the saved state."
            )

        # Consistent with torch.optim.Optimizer, cast things to consistent
        # device and dtype with the parameters
        params = [ref() for ref in tracked_refs]
        # If parameters have been garbage collected, just load the state
        # we were given without change.
        if not any(p is None for p in params):
            # ^ parameter references are still good
            for i, p in enumerate(params):
                # Averages of half-precision parameters live in float32; cast
                # straight to that dtype instead of rounding through the
                # parameter's own half-precision dtype first.
                if p.dtype in (torch.float16, torch.bfloat16):
                    shadow_dtype = torch.float32
                else:
                    shadow_dtype = p.dtype
                shadow_params[i] = shadow_params[i].to(
                    device=p.device, dtype=shadow_dtype
                )
                if collected_params is not None:
                    collected_params[i] = collected_params[i].to(
                        device=p.device, dtype=p.dtype
                    )

        self.decay = decay
        self.num_updates = num_updates
        self.shadow_params = shadow_params
        self.collected_params = collected_params