#!/usr/bin/env python3
"""
This module contains the EMA class, which stores a copy of the exponentially
decayed model params.

Typical usage of the EMA class involves initializing an object using an
existing model (random or from a seed model) and setting config options such
as ema_decay and ema_start_update, which determine how the EMA model is
updated. After every update of the model, i.e. at the end of the train_step,
the EMA should be updated by passing the new model to the EMA.step function.
The EMA model state dict can be stored in the extra state under the key
"ema", dumped into a checkpoint, and loaded back. The EMA object can be
passed to tasks by setting the task.uses_ema property.

EMA is a smoothed/ensemble model which might have better performance
when used for inference or further fine-tuning. The EMA class has a
reverse function to load the EMA params into a model so it can be used
like a regular model.

This implementation is used for trainer-level EMA tracking. For EMA tracking
inside the model, please use fairseq/modules/ema_module.py instead.
"""
import copy
import logging

import torch

from fairseq import checkpoint_utils

class EMA(object):
    """Exponential Moving Average of Fairseq Models

    EMA keeps a copy of the exponentially decayed model params.
    The set of params should include both gradient-descent and
    non-gradient-descent params, such as batch mean/var and buffers.

    This is a modified implementation of
    the open source code in https://github.com/zhawe01/fairseq-gec.git,
    and internal source code in
    fbcode/mobile-vision/projects/classification_pytorch/lib/utils/model_ema.py.

    Similar to TF EMA:
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
    EMA provides an averaged and smoothed set of model weights, and has been
    shown to improve vision models. The EMA class provides the functions
    needed to initialize, update, and reload the EMA params.

    The EMA object is initialized from an arbitrary model. By default, it is
    stored on the same device (unless a device is specified at initialization)
    and with the same precision as the model (unless ema_fp32 is True).
    ema_fp32 is recommended: the EMA params are then kept in an fp32 copy that
    is used only for the EMA update step, while the EMA model itself stays at
    the default precision.

    EMA is usually enabled using EMAConfig with store_ema=True. Some important
    parameters to configure EMA are
    1) ema_decay - The decay of EMA
    2) ema_update_freq - EMA is updated every this many model updates.
    3) ema_start_update - Start EMA update after this many model updates [default 0]

    Key methods:
    1) step - One update of EMA using the new model
    2) restore - Update EMA from a state dict
    3) reverse - Load EMA into a model
    4) get_decay, _set_decay - Used to get or set the decay. Note _set_decay is
       called from step.
    5) build_fp32_params - Used to initialize or update the fp32 copy of EMA
       params. Note this is enabled only when ema_fp32=True.
    """
    def __init__(self, model, config, device=None, skip_keys=None):
        """
        @param model model to initialize the EMA with
        @param config EMAConfig object with configuration like
            ema_decay, ema_update_freq, ema_fp32
        @param device If provided, copy EMA to this device (e.g. gpu).
            Otherwise EMA is on the same device as the model.
        """
        self.decay = config.ema_decay
        self.model = copy.deepcopy(model)
        self.model.requires_grad_(False)
        self.config = config
        self.skip_keys = skip_keys or set()
        self.fp32_params = {}

        if self.config.ema_seed_model is not None:
            state = checkpoint_utils.load_ema_from_checkpoint(
                self.config.ema_seed_model
            )
            self.model.load_state_dict(state["model"], strict=True)

        if device is not None:
            logging.info(f"Copying EMA model to device {device}")
            self.model = self.model.to(device=device)

        if self.config.ema_fp32:
            self.build_fp32_params()

        self.update_freq_counter = 0

    def get_model(self):
        return self.model

    def build_fp32_params(self, state_dict=None):
        """
        Store a copy of the EMA params in fp32.
        If a state dict is passed, the EMA params are copied from
        the provided state dict. Otherwise, they are copied from the
        current EMA model parameters.
        """
        if not self.config.ema_fp32:
            raise RuntimeError(
                "build_fp32_params should not be called if ema_fp32=False. "
                "Use ema_fp32=True if this is really intended."
            )

        if state_dict is None:
            state_dict = self.model.state_dict()

        def _to_float(t):
            return t.float() if torch.is_floating_point(t) else t
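        # Reuse an existing fp32 buffer when one already exists; otherwise
        # create a new fp32 copy (non-floating-point tensors, e.g. integer
        # buffers, keep their original dtype).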
        for param_key in state_dict:
            if param_key in self.fp32_params:
                self.fp32_params[param_key].copy_(state_dict[param_key])
            else:
                self.fp32_params[param_key] = _to_float(state_dict[param_key])
    def restore(self, state_dict, build_fp32_params=False):
        """Load a state dict into the EMA model."""
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)

    def _set_decay(self, decay):
        self.decay = decay

    def get_decay(self):
        return self.decay

    def _step_internal(self, new_model, updates=None):
        """One update of the EMA model based on new model weights"""
        decay = self.decay

        ema_state_dict = {}
        ema_params = (
            self.fp32_params if self.config.ema_fp32 else self.model.state_dict()
        )
        for key, param in new_model.state_dict().items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param "
                    "{} vs. {}".format(param.shape, ema_param.shape)
                )

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue
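            # Params listed in skip_keys are copied from the new model
            # directly; all other params get the standard EMA update:
            #   ema = decay * ema + (1 - decay) * param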
            if key in self.skip_keys:
                ema_param = param.to(dtype=ema_param.dtype).clone()
            else:
                ema_param.mul_(decay)
                ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
            ema_state_dict[key] = ema_param

        self.restore(ema_state_dict, build_fp32_params=False)
    def step(self, new_model, updates=None):
        """
        One update of EMA which is done every self.config.ema_update_freq
        updates of the model.

        @param updates The current number of model updates done.
        Decay is set to 0 if model updates < ema_start_update, which means
        the model is simply copied over to the EMA.
        When model updates >= ema_start_update, the EMA is updated with
        a decay of self.config.ema_decay.
        """
        if updates is not None:
            self._set_decay(
                0 if updates < self.config.ema_start_update else self.config.ema_decay
            )
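        # Only apply the EMA update once every ema_update_freq model updates.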
        if self.config.ema_update_freq > 1:
            self.update_freq_counter += 1
            if self.update_freq_counter >= self.config.ema_update_freq:
                self._step_internal(new_model, updates)
                self.update_freq_counter = 0
        else:
            self._step_internal(new_model, updates)
    def reverse(self, model):
        """
        Load the EMA model parameters into the given model.
        Useful for inference or for fine-tuning from the EMA model.
        """
        d = self.model.state_dict()
        if "_ema" in d:
            del d["_ema"]
        model.load_state_dict(d, strict=False)
        return model
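

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the EMA API). This is a minimal,
# hedged example of the trainer-level flow described in the module docstring.
# The dummy model and the `_DummyEMAConfig` class below are assumptions made
# for illustration only; in practice the config is a fairseq EMAConfig and the
# model is a fairseq model.
# ---------------------------------------------------------------------------
def _example_usage():
    import torch.nn as nn

    # Hypothetical stand-ins for a real fairseq model and EMAConfig.
    model = nn.Linear(8, 8)

    class _DummyEMAConfig:
        ema_decay = 0.999
        ema_update_freq = 1
        ema_start_update = 0
        ema_fp32 = True
        ema_seed_model = None

    ema = EMA(model, _DummyEMAConfig())

    # After every train step, pass the updated model to EMA.step.
    for update in range(1, 11):
        # ... an optimizer step on `model` would happen here ...
        ema.step(model, updates=update)

    # The EMA state dict can be stored in checkpoint extra state (e.g. under
    # the key "ema") and later loaded back with EMA.restore.
    extra_state = {"ema": ema.get_model().state_dict()}
    ema.restore(extra_state["ema"], build_fp32_params=True)

    # For inference or further fine-tuning, load the EMA params into a model
    # and use it like a regular model.
    eval_model = ema.reverse(copy.deepcopy(model))
    return eval_model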