"""Monitor rate of change of loss."""

from __future__ import annotations

from typing import Any

import torch


class FDiffMetrics(Callback):
    """Rate of change of metrics.

    Tracks and plots the rate of change of metrics, effectively taking the
    numerical (finite-difference) derivative of the metrics: each logged
    ``*_fdiff`` value is the current value minus the value recorded the
    previous time this callback ran.

    Args:
        diff_train_metrics (bool, optional): Whether to log the finite
            difference of the training loss and training metrics in
            ``batch_end``. Default: ``False``.
        diff_eval_metrics (bool, optional): Whether to log the finite
            difference of the evaluation metrics in ``eval_end``.
            Default: ``True``.
    """

    def __init__(self, diff_train_metrics: bool = False, diff_eval_metrics: bool = True):
        self.diff_train_metrics = diff_train_metrics
        self.diff_eval_metrics = diff_eval_metrics

        # Most recent training loss; ``None`` until the first batch is seen.
        self.train_prev_loss: float | None = None
        # Most recent value of each training metric, keyed by metric name.
        self.train_prev_metric: dict[str, Any] = {}
        # Most recent value of each eval metric, keyed by 'metrics/<evaluator>/<name>'.
        self.eval_prev_metric: dict[str, Any] = {}

    def batch_end(self, state: State, logger: Logger) -> None:
        """Log the finite difference of the training loss and training metrics.

        Does nothing unless ``diff_train_metrics`` is enabled. On the first
        call there is no previous value, so only the current values are stored.

        Raises:
            NotImplementedError: If ``state.loss`` is not a single ``torch.Tensor``.
        """
        if not self.diff_train_metrics:
            return

        if not isinstance(state.loss, torch.Tensor):
            raise NotImplementedError('Multiple losses not supported yet')
        loss = state.loss.item()
        # Compare against None rather than truthiness so that a previous loss of
        # exactly 0.0 still yields a difference.
        if self.train_prev_loss is not None:
            logger.log_metrics({'loss/train/total_fdiff': loss - self.train_prev_loss})
        self.train_prev_loss = loss

        current = state.train_metric_values
        for k, prev in self.train_prev_metric.items():
            # A metric seen on an earlier batch may be absent now; skip it instead
            # of raising KeyError.
            if k in current:
                logger.log_metrics({f'metrics/train/{k}_fdiff': current[k] - prev})

        # Remember this batch's values for the next call.
        self.train_prev_metric.update(current)

    def eval_end(self, state: State, logger: Logger) -> None:
        """Log the finite difference of the current evaluator's metrics.

        Does nothing unless ``diff_eval_metrics`` is enabled. A metric is only
        diffed once a previous value for the same evaluator/metric pair exists.
        """
        if not self.diff_eval_metrics:
            return

        evaluator = state.dataloader_label
        assert evaluator is not None, 'dataloader should have been set'

        for k in list(state.eval_metrics[evaluator].keys()):
            # Namespacing by evaluator keeps metrics of different evaluators apart.
            mkey = '/'.join(['metrics', evaluator, k])
            value = state.eval_metric_values[k]
            if mkey in self.eval_prev_metric:
                logger.log_metrics({f'{mkey}_fdiff': value - self.eval_prev_metric[mkey]})
            self.eval_prev_metric[mkey] = value