Spaces:
Running
Running
File size: 2,635 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
class DynamicLossScaler(object):
    """Keep a loss-scaling factor that grows and shrinks with overflow history.

    The scale is multiplied by ``scale_factor`` in :meth:`update` whenever the
    number of iterations since the last recorded overflow is a multiple of
    ``scale_window``. It is divided by ``scale_factor`` in
    :meth:`check_overflow` when the fraction of overflowing iterations since
    the last rescale reaches ``tolerance``.

    Args:
        init_scale: Initial value of ``loss_scale``.
        scale_factor: Multiplier used to grow, and divisor used to shrink,
            ``loss_scale``.
        scale_window: Period, in iterations counted from the last overflow,
            at which :meth:`update` grows the scale.
        tolerance: Fraction of overflowing iterations since the last rescale
            at or above which an overflow shrinks the scale.
        threshold: Optional lower bound applied to ``loss_scale`` after it is
            shrunk; ``None`` disables the bound.
        min_loss_scale: If a shrink leaves ``loss_scale`` at or below this
            value, :meth:`check_overflow` raises ``FloatingPointError``.
    """

    def __init__(
        self,
        init_scale=2.0**15,
        scale_factor=2.0,
        scale_window=2000,
        tolerance=0.0,
        threshold=None,
        min_loss_scale=1e-4,
    ):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.tolerance = tolerance
        self.threshold = threshold
        # Iteration counter; advanced by update() and by an overflow in
        # check_overflow() that ends in OverflowError.
        self._iter = 0
        # -1 means "no overflow / no rescale has been recorded yet".
        self._last_overflow_iter = -1
        self._last_rescale_iter = -1
        self._overflows_since_rescale = 0
        self.min_loss_scale = min_loss_scale

    def scale(self, outputs):
        """Return ``outputs`` multiplied by the current ``loss_scale``."""
        return self.loss_scale * outputs

    def update(self):
        """Advance one iteration, growing the scale every ``scale_window``
        iterations counted from the last overflow."""
        if (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
            self._last_rescale_iter = self._iter
        self._iter += 1

    def _decrease_loss_scale(self):
        """Divide ``loss_scale`` by ``scale_factor``, honoring ``threshold``."""
        self.loss_scale /= self.scale_factor
        if self.threshold is not None:
            self.loss_scale = max(self.loss_scale, self.threshold)

    def check_overflow(self, grad_norm):
        """Record an overflow if ``grad_norm`` is ``inf`` or ``nan``.

        Does nothing when ``grad_norm`` is neither. On overflow the scale may
        be decreased, and one of the exceptions below is always raised.

        Raises:
            FloatingPointError: The decreased scale is at or below
                ``min_loss_scale``. ``loss_scale`` is restored to its
                previous value first and the iteration counter is not
                advanced.
            OverflowError: Any other overflow; the message carries the
                current ``loss_scale``.
        """
        # detect inf and nan (nan is the only value that differs from itself)
        if not (grad_norm == float("inf") or grad_norm != grad_norm):
            return

        # overflow has occurred
        prev_scale = self.loss_scale
        # After a FloatingPointError, _last_rescale_iter equals _iter because
        # the counter was not advanced. Clamp to 1 so that a repeated overflow
        # report raises FloatingPointError again, not ZeroDivisionError.
        iter_since_rescale = max(self._iter - self._last_rescale_iter, 1)

        self._last_overflow_iter = self._iter
        self._overflows_since_rescale += 1
        pct_overflow = self._overflows_since_rescale / float(iter_since_rescale)
        if pct_overflow >= self.tolerance:
            self._decrease_loss_scale()
            self._last_rescale_iter = self._iter
            self._overflows_since_rescale = 0

        if self.loss_scale <= self.min_loss_scale:
            # Use FloatingPointError as an uncommon error that parent
            # functions can safely catch to stop training.
            self.loss_scale = prev_scale
            raise FloatingPointError(
                (
                    "Minimum loss scale reached ({}). Your loss is probably exploding. "
                    "Try lowering the learning rate, using gradient clipping or "
                    "increasing the batch size."
                ).format(self.min_loss_scale)
            )

        self._iter += 1
        raise OverflowError("setting loss scale to: " + str(self.loss_scale))
|