# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py
# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll))
# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py
# But we pass in the loss to avoid recomputation

from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from torchmetrics import Metric

try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLoss = torch.nn.CrossEntropyLoss

__all__ = ['Perplexity']


class Perplexity(Metric):
    r"""
    Perplexity measures how well a language model predicts a text sample.

    It is computed as ``exp(mean negative log-likelihood)``, where the mean is
    taken over every non-ignored target token seen since the last reset.
    Lower is better.

    Args:
        ignore_index: Target value that marks tokens excluded from the metric.
            It is forwarded to the cross-entropy loss and is also used to count
            the tokens each batch contributes.
        kwargs:
            Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Examples:
        >>> import torch
        >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22))
        >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22))
        >>> target[0, 6:] = -100
        >>> metric = Perplexity(ignore_index=-100)
        >>> metric(preds, target)
        tensor(5.2545)

    """
    is_differentiable = True
    higher_is_better = False
    full_state_update = False
    # Running sum of (mean NLL of a batch * number of tokens counted in it).
    total_log_probs: Tensor
    # Running number of non-ignored target tokens.
    count: Tensor

    def __init__(self, *, ignore_index: int = -100, **kwargs: Any):
        super().__init__(**kwargs)
        # float64 / int64 accumulators so that long evaluations do not lose precision.
        self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64),
                       dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum")

        self.ignore_index = ignore_index
        self.loss_fn = CrossEntropyLoss(ignore_index=ignore_index)

    def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None:  # type: ignore
        """Compute and store intermediate statistics for Perplexity.

        Args:
            preds:
                Unnormalized scores (logits) for each token in a sequence, with shape
                [batch_size, seq_len, vocab_size] or already flattened to
                [num_tokens, vocab_size]. Only used when ``loss`` is not given.
            target:
                Ground truth token ids with shape [batch_size, seq_len] (or [num_tokens]).
                Entries equal to ``ignore_index`` are not counted.
            loss:
                Optional precomputed cross-entropy loss for this batch, passed in to avoid
                recomputation. It is expected to be the mean over the non-ignored tokens
                of ``target`` (the default reduction of the loss function used here).
        """
        if loss is None:
            logits, labels = preds, target
            # The loss treats dim 1 as the class dim, so a documented
            # [batch, seq_len, vocab] input has to be flattened to [tokens, vocab].
            # Inputs already laid out for the loss are passed through untouched.
            if preds.dim() > 2 and preds.shape[:-1] == target.shape:
                logits = preds.reshape(-1, preds.size(-1))
                labels = target.reshape(-1)
            loss = self.loss_fn(logits, labels)

        # The mean loss averages over non-ignored tokens only, so weight it by that
        # count (not by target.numel()) to recover the summed NLL of the batch.
        num_valid = (target != self.ignore_index).sum().to(loss.device)
        weighted_nll = loss.double() * num_valid
        # A batch with no valid token has an undefined (NaN) mean loss; make it
        # contribute nothing instead of poisoning the running sum. torch.where
        # keeps this free of a host/device synchronization.
        weighted_nll = torch.where(num_valid > 0, weighted_nll, torch.zeros_like(weighted_nll))

        self.total_log_probs += weighted_nll
        self.count += num_valid

    def compute(self) -> Tensor:
        """Compute the Perplexity.

        Returns:
            ``exp(total NLL / number of counted tokens)``. The result is NaN when no
            token has been counted yet.
        """
        return torch.exp(self.total_log_probs / self.count)