# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq import utils
from fairseq.dataclass.utils import gen_parser_from_dataclass


class FairseqOptimizer(object):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    @classmethod
    def add_args(cls, parser):
        """Add optimizer-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @property
    def optimizer(self):
        """Return a torch.optim.optimizer.Optimizer instance."""
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        """Reset optimizer instance."""
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        self._optimizer = optimizer

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        raise NotImplementedError

    @property
    def params(self):
        """Return an iterable of the parameters held by the optimizer."""
        for param_group in self.param_groups:
            for p in param_group["params"]:
                yield p

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def __getstate__(self):
        return self._optimizer.__getstate__()

    def get_lr(self):
        """Return the current learning rate."""
        return self.param_groups[0]["lr"]

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.param_groups:
            param_group["lr"] = lr

    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        self.optimizer.load_state_dict(state_dict)

        if optimizer_overrides is not None and len(optimizer_overrides) > 0:
            # override learning rate, momentum, etc. with latest values
            for group in self.param_groups:
                group.update(optimizer_overrides)

    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
        loss.backward()

    def all_reduce_grads(self, module):
        """Manually all-reduce gradients (if required)."""
        if hasattr(module, "all_reduce_grads"):
            module.all_reduce_grads()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        for p in self.params:
            if p.grad is not None:
                if torch.is_tensor(c):
                    # move the multiplier onto the same device as the gradient
                    c = c.to(p.grad.device)
                p.grad.data.mul_(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)

    def step(self, closure=None, scale=1.0, groups=None):
        """Performs a single optimization step."""
        if self.supports_step_with_scale:
            # the wrapped optimizer can apply the gradient scale itself
            if self.supports_groups:
                self.optimizer.step(closure, scale=scale, groups=groups)
            else:
                self.optimizer.step(closure, scale=scale)
        else:
            # otherwise un-scale the gradients before stepping
            if scale != 1.0:
                self.multiply_grads(1.0 / scale)
            if self.supports_groups:
                self.optimizer.step(closure, groups=groups)
            else:
                self.optimizer.step(closure)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.params:
            p.grad = None
        self.optimizer.zero_grad()

    @property
    def supports_memory_efficient_fp16(self):
        if hasattr(self.optimizer, "supports_memory_efficient_fp16"):
            return self.optimizer.supports_memory_efficient_fp16
        return False

    @property
    def supports_step_with_scale(self):
        if hasattr(self.optimizer, "supports_step_with_scale"):
            return self.optimizer.supports_step_with_scale
        return False

    @property
    def supports_groups(self):
        if hasattr(self.optimizer, "supports_groups"):
            return self.optimizer.supports_groups
        return False

    @property
    def supports_flat_params(self):
        """
        Whether the optimizer supports collapsing of the model
        parameters/gradients into a single contiguous Tensor.
        """
        if hasattr(self.optimizer, "supports_flat_params"):
            return self.optimizer.supports_flat_params
        return False

    def average_params(self):
        pass

    def broadcast_global_state_dict(self, state_dict):
        """
        Broadcasts a global state dict to all ranks.
        Useful for optimizers that shard state between ranks.
        """
        if hasattr(self.optimizer, "broadcast_global_state_dict"):
            return self.optimizer.broadcast_global_state_dict(state_dict)
        else:
            return state_dict


class LegacyFairseqOptimizer(FairseqOptimizer):
    def __init__(self, args):
        self.args = args
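

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of fairseq): it illustrates how a concrete
# subclass typically wraps a torch optimizer by building the torch.optim
# instance in __init__ and exposing its constructor kwargs through the
# `optimizer_config` property. The class name `_ExampleSGD` and the cfg
# fields used below (`lr`, `momentum`) are illustrative assumptions, not
# part of the fairseq API.
# ---------------------------------------------------------------------------
class _ExampleSGD(FairseqOptimizer):
    def __init__(self, cfg, params):
        super().__init__(cfg)
        # self._optimizer backs the `optimizer` property defined above
        self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        # these kwargs override whatever was stored in a checkpoint when
        # training is resumed via load_state_dict()
        lr = self.cfg.lr[0] if isinstance(self.cfg.lr, (list, tuple)) else self.cfg.lr
        return {"lr": lr, "momentum": getattr(self.cfg, "momentum", 0.0)}


# Hypothetical training-loop usage of the sketch above (cfg and model are
# assumed to exist):
#   opt = _ExampleSGD(cfg, model.parameters())
#   loss = model(batch).sum()
#   opt.backward(loss)
#   opt.clip_grad_norm(max_norm=1.0)
#   opt.step()
#   opt.zero_grad()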