# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from itertools import chain

import torch
from omegaconf import DictConfig

from fairseq import optim

from .dynamic_loss_scaler import DynamicLossScaler
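
# Overview (a summary of the class docstrings below, added for orientation):
# this module provides two wrappers around a regular fairseq optimizer for
# mixed-precision training.
#   * FP16Optimizer keeps an FP32 master copy of the parameters (flat or
#     per-parameter) and runs the wrapped optimizer on that copy.
#   * MemoryEfficientFP16Optimizer skips the FP32 master copy and expects the
#     wrapped optimizer to convert gradients to FP32 internally, trading a bit
#     of optimizer time for a much smaller memory footprint.
# Both defer gradient unscaling and clipping into a single fused multiply via
# `_multiply_factor`, and both use DynamicLossScaler unless bf16 is enabled.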


class _FP16OptimizerMixin(object):
    def __init__(self, *args, **kwargs):
        # forward __init__ call to the next class in MRO (method resolution order)
        super().__init__(*args, **kwargs)
        self._multiply_factor = 1.0

    @property
    def has_flat_params(self):
        return torch.is_tensor(self.fp32_params) or (
            isinstance(self.fp32_params, dict)
            and all(torch.is_tensor(t) for t in self.fp32_params.values())
        )

    @classmethod
    def build_fp32_params(cls, args, params, flatten=True):
        # create FP32 copy of parameters and grads
        if flatten:
            is_pipeline_parallel = getattr(
                args, "pipeline_model_parallel", False
            ) and getattr(args, "distributed_no_spawn", False)
            total_param_size = sum(p.data.numel() for p in params)
            devices = [torch.cuda.current_device()]
            if is_pipeline_parallel:
                devices = list(set(args.pipeline_devices))
            fp32_params = {}
            for device in devices:
                if is_pipeline_parallel:
                    device_param_size = sum(
                        p.data.numel() for p in params if p.device.index == device
                    )
                    device_params = [p for p in params if p.device.index == device]
                else:
                    device_param_size = total_param_size
                    device_params = params
                fp32_params[device] = (
                    device_params[0].new(0).float().new(device_param_size)
                )
                offset = 0
                for p in device_params:
                    numel = p.data.numel()
                    fp32_params[device][offset : offset + numel].copy_(p.data.view(-1))
                    offset += numel
                fp32_params[device] = torch.nn.Parameter(fp32_params[device])
                fp32_params[device].grad = fp32_params[device].data.new(
                    device_param_size
                )
            return fp32_params
        else:
            fp32_params = []
            for p in params:
                p32 = torch.nn.Parameter(p.data.float())
                if hasattr(p, "expert"):
                    p32.expert = True
                elif hasattr(p, "base_expert"):
                    p32.base_expert = True
                p32.grad = torch.zeros_like(p32.data)
                if hasattr(p, "param_group"):
                    p32.param_group = p.param_group
                fp32_params.append(p32)
            return fp32_params
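
    # Illustrative note (not part of the original source): with flatten=True and,
    # say, two FP16 parameters of 4 and 6 elements on a single device, the loop
    # above produces one 10-element FP32 master tensor whose slices [0:4] and
    # [4:10] hold the flattened copies of each parameter, plus a flat .grad
    # buffer of the same total size.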

    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.fp32_optimizer.state_dict()
        if self.scaler is not None:
            state_dict["loss_scale"] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if "loss_scale" in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict["loss_scale"]
        self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()
        self._needs_sync = True
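
    # Usage sketch (a hedged illustration, not taken from the original file): a
    # typical training update with this wrapper looks roughly like
    #
    #     optimizer.zero_grad()
    #     optimizer.backward(loss)                      # scales the loss first
    #     optimizer.multiply_grads(1.0 / sample_size)   # optional normalization
    #     grad_norm = optimizer.clip_grad_norm(max_norm)
    #     optimizer.step()   # unscales, clips, updates FP32, copies back to FP16
    #
    # `sample_size` and `max_norm` are placeholder names for values supplied by
    # the training loop.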

    def _sync_fp16_grads_to_fp32(self):
        if self._needs_sync:
            # copy FP16 grads to FP32
            if self.has_flat_params:
                devices = list(self.fp32_params.keys())
                device_params_dict = defaultdict(list)
                for p in self.fp16_params:
                    if p.requires_grad:
                        device_params_dict[p.device.index].append(p)
                for device in devices:
                    device_params = device_params_dict[device]
                    offset = 0
                    for p in device_params:
                        grad_data = (
                            p.grad.data
                            if p.grad is not None
                            else p.data.new_zeros(p.data.shape)
                        )
                        numel = grad_data.numel()
                        self.fp32_params[device].grad.data[
                            offset : offset + numel
                        ].copy_(grad_data.view(-1))
                        offset += numel
            else:
                for p, p32 in zip(self.fp16_params, self.fp32_params):
                    if not p.requires_grad:
                        continue
                    if p.grad is not None:
                        if p32.grad is None:
                            p32.grad = p.grad.data.float()
                        else:
                            p32.grad.data.copy_(p.grad.data)
                    else:
                        p32.grad = torch.zeros_like(p.data, dtype=torch.float)

            self._needs_sync = False

    def _sync_fp32_params_to_fp16(self):
        # copy FP32 params back into FP16 model
        if self.has_flat_params:
            devices = list(self.fp32_params.keys())
            device_params_dict = defaultdict(list)
            for p in self.fp16_params:
                device_params_dict[p.device.index].append(p)
            for device in devices:
                device_params = device_params_dict[device]
                offset = 0
                for p in device_params:
                    numel = p.data.numel()
                    p.data.copy_(
                        self.fp32_params[device]
                        .data[offset : offset + numel]
                        .view_as(p.data)
                    )
                    offset += numel
        else:
            for p, p32 in zip(self.fp16_params, self.fp32_params):
                if not p.requires_grad:
                    continue
                p.data.copy_(p32.data)

    def _unscale_grads(self):
        self._sync_fp16_grads_to_fp32()
        if (
            # Skip the multiplication if it's a no-op (i.e., if _multiply_factor
            # is 1.0). At the same time, we want to avoid the device-to-host
            # transfer by comparing it to 1.0. Since _multiply_factor starts as
            # a Python float, we roughly assume that if it's a tensor then it's
            # probably not =1.0 anymore and we do the multiplication. Otherwise
            # we can safely check the value without a D2H transfer.
            torch.is_tensor(self._multiply_factor)
            or self._multiply_factor != 1.0
        ):
            self.fp32_optimizer.multiply_grads(self._multiply_factor)
            self._multiply_factor = 1.0

    def multiply_grads(self, c):
        """Multiplies grads by a constant ``c``."""
        self._multiply_factor *= c

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        self._sync_fp16_grads_to_fp32()

        grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(
            0, aggregate_norm_fn
        )

        if self.scaler is not None:
            if grad_norm > max_norm > 0.0:
                self._multiply_factor *= max_norm / grad_norm

            self.scaler.check_overflow(grad_norm)
        elif max_norm > 0.0:
            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
            self._multiply_factor *= clip_coef

        return grad_norm
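
    # Worked example (illustrative, with made-up numbers): after zero_grad()
    # with loss_scale = 128, _multiply_factor starts at 1/128. The raw FP16
    # grads hold 128 * g, so the norm returned above is (1/128) * ||128 * g||,
    # i.e. the true ||g||. If that norm is 10 and max_norm is 1, the factor
    # becomes (1/128) * (1/10), and _unscale_grads() later applies unscaling
    # and clipping in one fused multiply.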

    def step(self, closure=None, groups=None):
        """Performs a single optimization step."""
        self._sync_fp16_grads_to_fp32()

        if getattr(self, "supports_step_with_scale", False):
            self.fp32_optimizer.step(
                closure, scale=(1.0 / self._multiply_factor), groups=groups
            )
        else:
            self._unscale_grads()
            self.fp32_optimizer.step(closure, groups=groups)

        if self.scaler is not None:
            self.scaler.update()

        self._sync_fp32_params_to_fp16()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.fp16_params:
            p.grad = None
        if self.has_flat_params:
            if torch.is_tensor(self.fp32_params):
                self.fp32_params.grad.zero_()
            elif isinstance(self.fp32_params, dict):
                for fp32_params in self.fp32_params.values():
                    fp32_params.grad.zero_()
            else:
                raise RuntimeError("self.fp32_params must be a tensor or dict")
        else:
            for p32 in self.fp32_params:
                if p32.grad is not None:
                    p32.grad.zero_()
        self._needs_sync = False

        if self.scaler is not None:
            self._multiply_factor = 1.0 / float(self.scaler.loss_scale)


class FP16Optimizer(_FP16OptimizerMixin, optim.FairseqOptimizer):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.
    """

    def __init__(self, cfg: DictConfig, params, fp32_optimizer, fp32_params, **kwargs):
        super().__init__(cfg.optimizer)
        self.fp16_params = params
        self.fp32_optimizer = fp32_optimizer
        self.fp32_params = fp32_params

        if getattr(cfg.common, "fp16_scale_window", None) is None:
            if len(cfg.optimization.update_freq) > 1:
                raise ValueError(
                    "--fp16-scale-window must be given explicitly when using a "
                    "custom --update-freq schedule"
                )
            data_parallel_size = int(
                cfg.distributed_training.distributed_world_size
                / cfg.common.model_parallel_size
            )
            scale_window = int(
                2**14 / data_parallel_size / cfg.optimization.update_freq[0]
            )
        else:
            scale_window = cfg.common.fp16_scale_window

        if not getattr(cfg.common, "bf16", False):
            self.scaler = DynamicLossScaler(
                init_scale=cfg.common.fp16_init_scale,
                scale_window=scale_window,
                tolerance=cfg.common.fp16_scale_tolerance,
                threshold=cfg.common.threshold_loss_scale,
                min_loss_scale=cfg.common.min_loss_scale,
            )
        else:
            # disable loss scaling for bfloat16
            self.scaler = None
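
    # Example of the default above (illustrative numbers): with a distributed
    # world size of 8, model_parallel_size 1 and update_freq [1], the default
    # scale window is int(2**14 / 8 / 1) = 2048 updates without overflow before
    # the loss scale is increased.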

    @classmethod
    def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
        """
        Args:
            cfg (omegaconf.DictConfig): fairseq args
            params (iterable): iterable of parameters to optimize
        """
        flatten = not getattr(cfg.common, "fp16_no_flatten_grads", False)
        if getattr(cfg.common, "bf16", False):
            flatten = False  # mixed precision is faster on TPUs without flat grads
        fp32_params = cls.build_fp32_params(cfg.optimizer, params, flatten=flatten)
        if flatten:
            fp32_optimizer = optim.build_optimizer(cfg.optimizer, [fp32_params])
        else:
            fp32_optimizer = optim.build_optimizer(cfg.optimizer, fp32_params)
        if flatten and not fp32_optimizer.supports_flat_params:
            raise RuntimeError(
                f"chosen optimizer {fp32_optimizer.__class__.__name__} does not "
                "support flat params, please set --fp16-no-flatten-grads"
            )
        return cls(cfg, params, fp32_optimizer, fp32_params, **kwargs)
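
    # Hedged usage sketch (names like `cfg` and `model` are placeholders, not
    # part of this file): a trainer would typically do
    #
    #     params = list(filter(lambda p: p.requires_grad, model.parameters()))
    #     optimizer = FP16Optimizer.build_optimizer(cfg, params)
    #
    # which builds the FP32 master copy, the wrapped FP32 optimizer and the
    # dynamic loss scaler in one call.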

    @property
    def optimizer(self):
        return self.fp32_optimizer.optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        self.fp32_optimizer.optimizer = optimizer

    @property
    def lr_scheduler(self):
        return getattr(self.fp32_optimizer, "lr_scheduler", None)

    @property
    def optimizer_config(self):
        return self.fp32_optimizer.optimizer_config

    def get_lr(self):
        return self.fp32_optimizer.get_lr()

    def set_lr(self, lr):
        self.fp32_optimizer.set_lr(lr)

    def all_reduce_grads(self, module):
        self.fp32_optimizer.all_reduce_grads(module)

    @property
    def supports_flat_params(self):
        return self.fp32_optimizer.supports_flat_params


class _MemoryEfficientFP16OptimizerMixin(object):
    def __init__(self, *args, **kwargs):
        # forward __init__ call to the next class in MRO (method resolution order)
        super().__init__(*args, **kwargs)
        self._multiply_factor = 1.0

    @property
    def has_flat_params(self):
        return False

    def state_dict(self):
        """Return the optimizer's state dict."""
        state_dict = self.wrapped_optimizer.state_dict()
        if self.scaler is not None:
            state_dict["loss_scale"] = self.scaler.loss_scale
        return state_dict

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        if "loss_scale" in state_dict and self.scaler is not None:
            self.scaler.loss_scale = state_dict["loss_scale"]

        self.wrapped_optimizer.load_state_dict(state_dict, optimizer_overrides)

        # Hack: PyTorch automatically casts the optimizer state to match the
        # type of the current parameters. But with --memory-efficient-fp16 the
        # params are FP16 while the optimizer state is FP32 and we don't want
        # to cast. A workaround is to manually copy back the original state
        # after the optimizer has been loaded.
        if not getattr(self.optimizer, "disable_mem_eff_fp16_loading_hack", False):
            groups = self.optimizer.param_groups
            saved_groups = state_dict["param_groups"]
            id_map = {
                old_id: p
                for old_id, p in zip(
                    chain(*(g["params"] for g in saved_groups)),
                    chain(*(g["params"] for g in groups)),
                )
            }
            for k, v in state_dict["state"].items():
                if k in id_map:
                    param = id_map[k]
                    self.optimizer.state[param] = v
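
    # Note on the mapping above: in a saved state dict each parameter is keyed
    # by an integer id listed in the same order as the live param_groups, so
    # zipping the two flattened "params" lists pairs every saved id with the
    # matching in-memory parameter before the FP32 state tensors are copied back.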

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves.

        Compared to :func:`fairseq.optim.FairseqOptimizer.backward`, this
        function additionally dynamically scales the loss to avoid gradient
        underflow.
        """
        if self.scaler is not None:
            loss = self.scaler.scale(loss)
        loss.backward()

    def _unscale_grads(self):
        if (
            # Skip the multiplication if it's a no-op (i.e., if _multiply_factor
            # is 1.0). At the same time, we want to avoid the device-to-host
            # transfer by comparing it to 1.0. Since _multiply_factor starts as
            # a Python float, we roughly assume that if it's a tensor then it's
            # probably not =1.0 anymore and we do the multiplication. Otherwise
            # we can safely check the value without a D2H transfer.
            torch.is_tensor(self._multiply_factor)
            or self._multiply_factor != 1.0
        ):
            self.wrapped_optimizer.multiply_grads(self._multiply_factor)
            self._multiply_factor = 1.0

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        self._multiply_factor *= c

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm and updates dynamic loss scaler."""
        max_norm = float(max_norm)
        grad_norm = self._multiply_factor * self.wrapped_optimizer.clip_grad_norm(
            0, aggregate_norm_fn
        )

        if self.scaler is not None:
            grad_norm_cpu = float(grad_norm)
            if grad_norm_cpu > max_norm > 0.0:
                self._multiply_factor *= max_norm / grad_norm_cpu

            # detect overflow and adjust loss scale
            self.scaler.check_overflow(grad_norm_cpu)
        elif max_norm > 0.0:
            clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
            self._multiply_factor *= clip_coef

        return grad_norm

    def step(self, closure=None, groups=None):
        """Performs a single optimization step."""
        if getattr(self, "supports_step_with_scale", False):
            # NOTE(msb) optimizer divides by scale factor
            self.wrapped_optimizer.step(
                closure, scale=(1.0 / self._multiply_factor), groups=groups
            )
        else:
            self._unscale_grads()
            self.wrapped_optimizer.step(closure, groups=groups)

        if self.scaler is not None:
            self.scaler.update()

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        self.wrapped_optimizer.zero_grad()
        if self.scaler is not None:
            self._multiply_factor = 1.0 / float(self.scaler.loss_scale)
        else:
            self._multiply_factor = 1.0

    @property
    def supports_flat_params(self):
        return self.wrapped_optimizer.supports_flat_params


class MemoryEfficientFP16Optimizer(
    _MemoryEfficientFP16OptimizerMixin, optim.FairseqOptimizer
):
    """
    Wrap an *optimizer* to support FP16 (mixed precision) training.

    Compared to :class:`fairseq.optim.FP16Optimizer`, this version does not
    maintain an FP32 copy of the model. We instead expect the optimizer to
    convert the gradients to FP32 internally and sync the results back to the
    FP16 model params. This significantly reduces memory usage but slightly
    increases the time spent in the optimizer.

    Since this wrapper depends on specific functionality in the wrapped
    optimizer (i.e., on-the-fly conversion of grads to FP32), only certain
    optimizers can be wrapped. This is determined by the
    *supports_memory_efficient_fp16* property.
    """

    def __init__(
        self, cfg: DictConfig, params, optimizer, allow_unsupported=False, **kwargs
    ):
        if not allow_unsupported and not optimizer.supports_memory_efficient_fp16:
            raise ValueError(
                "Unsupported optimizer: {}".format(optimizer.__class__.__name__)
            )

        super().__init__(getattr(cfg, "optimizer", None))
        self.wrapped_optimizer = optimizer

        if getattr(cfg.common, "fp16_scale_window", None) is None:
            if len(cfg.optimization.update_freq) > 1:
                raise ValueError(
                    "--fp16-scale-window must be given explicitly when using a "
                    "custom --update-freq schedule"
                )
            data_parallel_size = int(
                cfg.distributed_training.distributed_world_size
                / cfg.common.model_parallel_size
            )
            scale_window = int(
                2**14 / data_parallel_size / cfg.optimization.update_freq[0]
            )
        else:
            scale_window = cfg.common.fp16_scale_window

        if not getattr(cfg.common, "bf16", False):
            self.scaler = DynamicLossScaler(
                init_scale=cfg.common.fp16_init_scale,
                scale_window=scale_window,
                tolerance=cfg.common.fp16_scale_tolerance,
                threshold=cfg.common.threshold_loss_scale,
                min_loss_scale=cfg.common.min_loss_scale,
            )
        else:
            # disable loss scaling for bfloat16
            self.scaler = None

    @classmethod
    def build_optimizer(cls, cfg: DictConfig, params, **kwargs):
        """
        Args:
            cfg (omegaconf.DictConfig): fairseq args
            params (iterable): iterable of parameters to optimize
        """
        fp16_optimizer = optim.build_optimizer(cfg.optimizer, params)
        return cls(cfg, params, fp16_optimizer, **kwargs)
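
    # Hedged usage sketch (identifiers `cfg` and `model` are placeholders): with
    # an optimizer whose `supports_memory_efficient_fp16` property is True,
    #
    #     params = list(model.half().parameters())
    #     optimizer = MemoryEfficientFP16Optimizer.build_optimizer(cfg, params)
    #
    # builds the underlying optimizer directly on the FP16 params; no FP32
    # master copy is allocated.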

    @property
    def optimizer(self):
        return self.wrapped_optimizer.optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        self.wrapped_optimizer.optimizer = optimizer

    @property
    def optimizer_config(self):
        return self.wrapped_optimizer.optimizer_config

    @property
    def lr_scheduler(self):
        return getattr(self.wrapped_optimizer, "lr_scheduler", None)

    def get_lr(self):
        return self.wrapped_optimizer.get_lr()

    def set_lr(self, lr):
        self.wrapped_optimizer.set_lr(lr)

    def all_reduce_grads(self, module):
        self.wrapped_optimizer.all_reduce_grads(module)