import logging

import torch


logger = logging.getLogger(__name__)


class NanDetector:
    """
    Detects the first NaN or Inf in the forward and/or backward pass and logs it,
    together with the name of the module that produced it.
    """

    def __init__(self, model, forward=True, backward=True):
        self.bhooks = []
        self.fhooks = []
        self.forward = forward
        self.backward = backward
        self.named_parameters = list(model.named_parameters())
        self.reset()

        # Tag each submodule with its name so the hooks can report exactly
        # where the first non-finite value appeared.
        for name, mod in model.named_modules():
            mod.__module_name = name
            self.add_hooks(mod)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # On exit, compute all parameter gradient norms to ease debugging; raw
        # gradients are only dumped for parameters whose norm is NaN/Inf.
        norm = {}
        gradients = {}
        for name, param in self.named_parameters:
            if param.grad is not None:
                grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32)
                norm[name] = grad_norm.item()
                if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any():
                    gradients[name] = param.grad.data
        if len(gradients) > 0:
            logger.info("Detected nan/inf grad norm, dumping norms...")
            logger.info(f"norms: {norm}")
            logger.info(f"gradients: {gradients}")

        self.close()

    def add_hooks(self, module):
        if self.forward:
            self.fhooks.append(module.register_forward_hook(self.fhook_fn))
        if self.backward:
            # Note: register_backward_hook is deprecated in newer PyTorch releases
            # in favor of register_full_backward_hook.
            self.bhooks.append(module.register_backward_hook(self.bhook_fn))

    def reset(self):
        # Only the first NaN/Inf per direction (forward/backward) is reported.
        self.has_printed_f = False
        self.has_printed_b = False

    def _detect(self, tensor, name, backward):
        err = None
        if (
            torch.is_floating_point(tensor)
            # single-value tensors (like the loss) do not provide much information
            and tensor.numel() >= 2
        ):
            with torch.no_grad():
                if torch.isnan(tensor).any():
                    err = "NaN"
                elif torch.isinf(tensor).any():
                    err = "Inf"
        if err is not None:
            err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}"
        return err

    def _apply(self, module, inp, x, backward):
        if torch.is_tensor(x):
            # Only the first input is reported alongside the offending output.
            if isinstance(inp, tuple) and len(inp) > 0:
                inp = inp[0]
            err = self._detect(x, module.__module_name, backward)
            if err is not None:
                if torch.is_tensor(inp) and not backward:
                    err += (
                        f" input max: {inp.max().item()}, input min: {inp.min().item()}"
                    )

                has_printed_attr = "has_printed_b" if backward else "has_printed_f"
                logger.warning(err)
                setattr(self, has_printed_attr, True)
        elif isinstance(x, dict):
            # Recurse into container outputs to find the offending tensor.
            for v in x.values():
                self._apply(module, inp, v, backward)
        elif isinstance(x, (list, tuple)):
            for v in x:
                self._apply(module, inp, v, backward)

    def fhook_fn(self, module, inp, output):
        if not self.has_printed_f:
            self._apply(module, inp, output, backward=False)

    def bhook_fn(self, module, inp, output):
        if not self.has_printed_b:
            self._apply(module, inp, output, backward=True)

    def close(self):
        for hook in self.fhooks + self.bhooks:
            hook.remove()
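

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the toy model and the
    # deliberately non-finite input below are illustrative assumptions, meant only
    # to show how the context manager is intended to wrap a forward/backward pass.
    logging.basicConfig(level=logging.INFO)

    demo_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    with NanDetector(demo_model):
        inputs = torch.full((2, 4), float("inf"))  # non-finite input to trigger a report
        out = demo_model(inputs)
        out.sum().backward()
    # Expected effect: a warning naming the first module whose output contained a
    # NaN/Inf, plus (on exit) a dump of any non-finite parameter gradient norms.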