Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
class DynamicLossScaler(object):
    """Maintain a multiplicative loss scale that adapts to gradient overflows.

    The scale is multiplied by ``scale_factor`` from :meth:`update` each time
    the number of iterations since the last recorded overflow is a multiple
    of ``scale_window``, and divided by ``scale_factor`` from
    :meth:`check_overflow` when an inf/NaN gradient norm is reported.

    Args:
        init_scale: Initial value of ``loss_scale``.
        scale_factor: Factor by which the scale is multiplied when growing
            and divided when shrinking.
        scale_window: The scale grows whenever the iteration count since the
            last overflow is a multiple of this value.
        tolerance: Minimum fraction of overflowing iterations (since the last
            rescale) required before an overflow shrinks the scale. With the
            default ``0.0`` every overflow shrinks it.
        threshold: Optional lower bound applied to the scale after it is
            decreased; ``None`` disables the bound.
        min_loss_scale: If a decrease leaves the scale at or below this
            value, :meth:`check_overflow` raises ``FloatingPointError``.
    """

    def __init__(
        self,
        init_scale=2.0**15,
        scale_factor=2.0,
        scale_window=2000,
        tolerance=0.0,
        threshold=None,
        min_loss_scale=1e-4,
    ):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.tolerance = tolerance
        self.threshold = threshold
        # Iteration counter, advanced by update() and by an overflowing
        # check_overflow() call.
        self._iter = 0
        # -1 means "no overflow / no rescale seen yet".
        self._last_overflow_iter = -1
        self._last_rescale_iter = -1
        self._overflows_since_rescale = 0
        self.min_loss_scale = min_loss_scale

    def scale(self, outputs):
        """Return ``outputs`` multiplied by the current loss scale."""
        return self.loss_scale * outputs

    def update(self):
        """Advance one iteration, growing the scale at window boundaries."""
        if (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
            self._last_rescale_iter = self._iter
        self._iter += 1

    def _decrease_loss_scale(self):
        """Divide the scale by ``scale_factor``, clamped below by ``threshold``."""
        self.loss_scale /= self.scale_factor
        if self.threshold is not None:
            self.loss_scale = max(self.loss_scale, self.threshold)

    def check_overflow(self, grad_norm):
        """Record an overflow if ``grad_norm`` is inf or NaN.

        Returns ``None`` without touching any state when ``grad_norm`` is
        neither inf nor NaN.

        Raises:
            FloatingPointError: If decreasing the scale leaves it at or below
                ``min_loss_scale``. The previous scale is restored first.
            OverflowError: On every other detected overflow, after the
                bookkeeping (and possibly the scale) has been updated.
        """
        # detect inf and nan (NaN is the only value that is not equal to itself)
        if not (grad_norm == float("inf") or grad_norm != grad_norm):
            return

        # overflow has occurred
        prev_scale = self.loss_scale
        # Clamp to 1: after a FloatingPointError below, _last_rescale_iter
        # equals _iter (the counter is not advanced on that path), so a
        # repeated call would otherwise divide by zero instead of reporting
        # the minimum-loss-scale condition again.
        iter_since_rescale = max(self._iter - self._last_rescale_iter, 1)

        self._last_overflow_iter = self._iter
        self._overflows_since_rescale += 1
        pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)
        if pct_overflow >= self.tolerance:
            self._decrease_loss_scale()
            self._last_rescale_iter = self._iter
            self._overflows_since_rescale = 0

        if self.loss_scale <= self.min_loss_scale:
            # Use FloatingPointError as an uncommon error that parent
            # functions can safely catch to stop training.
            self.loss_scale = prev_scale
            raise FloatingPointError(
                (
                    "Minimum loss scale reached ({}). Your loss is probably exploding. "
                    "Try lowering the learning rate, using gradient clipping or "
                    "increasing the batch size."
                ).format(self.min_loss_scale)
            )

        self._iter += 1
        raise OverflowError("setting loss scale to: " + str(self.loss_scale))