pikto's picture
Duplicate from algovenus/text-generation-webui
82fea12
raw
history blame contribute delete
No virus
6.45 kB
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
import torch
from .state import AcceleratorState, GradientState
from .utils import DistributedType, honor_type, is_tpu_available
if is_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
def move_to_device(state, device):
if isinstance(state, (list, tuple)):
return honor_type(state, (move_to_device(t, device) for t in state))
elif isinstance(state, dict):
return type(state)({k: move_to_device(v, device) for k, v in state.items()})
elif isinstance(state, torch.Tensor):
return state.to(device)
return state
class AcceleratedOptimizer(torch.optim.Optimizer):
"""
Internal wrapper around a torch optimizer.
Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient
accumulation.
Args:
optimizer (`torch.optim.optimizer.Optimizer`):
The optimizer to wrap.
device_placement (`bool`, *optional*, defaults to `True`):
Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of
`optimizer` on the right device.
scaler (`torch.cuda.amp.grad_scaler.GradScaler`, *optional*):
The scaler to use in the step function if training with mixed precision.
"""
def __init__(self, optimizer, device_placement=True, scaler=None):
self.optimizer = optimizer
self.scaler = scaler
self.accelerator_state = AcceleratorState()
self.gradient_state = GradientState()
self.device_placement = device_placement
self._is_overflow = False
self._last_scale = None
# Handle device placement
if device_placement:
state_dict = self.optimizer.state_dict()
if self.accelerator_state.distributed_type == DistributedType.TPU:
xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
else:
state_dict = move_to_device(state_dict, self.accelerator_state.device)
self.optimizer.load_state_dict(state_dict)
@property
def state(self):
return self.optimizer.state
@state.setter
def state(self, state):
self.optimizer.state = state
@property
def param_groups(self):
return self.optimizer.param_groups
@param_groups.setter
def param_groups(self, param_groups):
self.optimizer.param_groups = param_groups
@property
def defaults(self):
return self.optimizer.defaults
@defaults.setter
def defaults(self, defaults):
self.optimizer.defaults = defaults
def add_param_group(self, param_group):
self.optimizer.add_param_group(param_group)
def load_state_dict(self, state_dict):
if self.accelerator_state.distributed_type == DistributedType.TPU and self.device_placement:
xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
self.optimizer.load_state_dict(state_dict)
def state_dict(self):
return self.optimizer.state_dict()
def zero_grad(self, set_to_none=None):
if self.gradient_state.sync_gradients:
accept_arg = "set_to_none" in inspect.signature(self.optimizer.zero_grad).parameters
if accept_arg:
if set_to_none is None:
set_to_none = False
self.optimizer.zero_grad(set_to_none=set_to_none)
else:
if set_to_none is not None:
raise ValueError("`set_to_none` for Optimizer.zero_grad` is not supported by this optimizer.")
self.optimizer.zero_grad()
def step(self, closure=None):
if self.gradient_state.sync_gradients:
if self.accelerator_state.distributed_type == DistributedType.TPU:
optimizer_args = {"closure": closure} if closure is not None else {}
xm.optimizer_step(self.optimizer, optimizer_args=optimizer_args)
elif self.scaler is not None:
new_scale = False
if self._last_scale is None:
# `get_scale` is an async operation requiring full synchronization
# on CPU and GPUs before finishing. As a result, we store away
# the prior one to reduce the call overhead
self._last_scale = self.scaler.get_scale()
new_scale = True
self.scaler.step(self.optimizer, closure)
self.scaler.update()
scale_after = self.scaler.get_scale()
if not new_scale:
# If we reduced the loss scale, it means the optimizer step was skipped because of gradient overflow.
self._is_overflow = scale_after < self._last_scale
self._last_scale = scale_after
else:
self.optimizer.step(closure)
def _switch_parameters(self, parameters_map):
for param_group in self.optimizer.param_groups:
param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]
@property
def is_overflow(self):
"""Whether or not the optimizer step was done, or skipped because of gradient overflow."""
warnings.warn(
"The `is_overflow` property is deprecated and will be removed in version 1.0 of Accelerate use "
"`optimizer.step_was_skipped` instead.",
FutureWarning,
)
return self._is_overflow
@property
def step_was_skipped(self):
"""Whether or not the optimizer step was skipped."""
return self._is_overflow
def __getstate__(self):
return self.__dict__.copy()
def __setstate__(self, state):
self.__dict__.update(state)