import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_only

import torch
from torch.autograd import grad

class CausalityMonitor(Callback):
    r"""Monitor the causality of a model by tracking gradient leakage forward in time.

    In a fully causal model, dy[k]/du[s] ≈ 0 for all s > k: the output at
    position k should receive no gradient from inputs at later positions.

    Args:
        seq_len (int): Length of the dummy sequence used for the check.
        input_dim (int): Dimension of the input to monitor. If 0, the callback
            assumes the task is language modeling and skips the embedding layer.
            If > 0, input_dim is interpreted as the input channel dimension,
            i.e. D for a dummy input of shape [B, L, D].

    Notes:
        This callback assumes that `pl_module.model` has a `net` or `s4seq`
        attribute pointing at the sub-module to monitor. For LMs, `net` or
        `s4seq` should be the part of the model after the embedding layer.
    """
    def __init__(self, seq_len: int = 10, input_dim: int = 0):
        super().__init__()
        self.seq_len = seq_len
        self.input_dim = input_dim
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        model = pl_module.model

        with torch.enable_grad():
            if self.input_dim == 0:
                # [MP] LongTensors cannot have gradients - in the LM case we start
                # from a dummy post-embedding activation instead of token ids
                input_dim = model.d_model
                x = torch.randn(
                    2, self.seq_len, input_dim,
                    device=pl_module.device, requires_grad=True,
                )
                # [DF] HACK: we need to get the layer that comes after the embedding
                if hasattr(model, 'net'):
                    y = model.net(x)
                else:
                    y = model.s4seq(x)
            else:
                x = torch.randn(
                    1, self.seq_len, self.input_dim,
                    device=pl_module.device, requires_grad=True,
                )
                y = model(x)
            stats = {}
            for i in range(self.seq_len):
                # total gradient flowing from the output at position i back to x
                g = grad(y[0, i].mean(), x, retain_graph=True, allow_unused=True)[0]
                # causality leakage: gradient mass coming from inputs at positions > i
                g = g[0, i + 1:, :].abs().mean()
                stats[f'stats/causality_{i}'] = g.item()

        if trainer.loggers:
            for logger in trainer.loggers:
                logger.log_metrics(stats, step=trainer.global_step)
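

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one way to attach the monitor to a
# Lightning Trainer. `MySequenceModule` is a hypothetical LightningModule whose
# `.model` attribute exposes a `net` (or `s4seq`) sub-module, as the callback
# above expects; swap in your own module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # check 128 positions of a 16-channel continuous input
    monitor = CausalityMonitor(seq_len=128, input_dim=16)
    trainer = pl.Trainer(max_epochs=1, callbacks=[monitor])
    # trainer.fit(MySequenceModule())  # hypothetical module, supplied by the user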