mpt-7b-8k-instruct / fdiff_callback.py
irenedea's picture
LLM-foundry update March 26, 2024 23:50:31
2cc518e verified
raw
history blame
2.06 kB
"""Monitor rate of change of loss."""
from __future__ import annotations
import torch
class FDiffMetrics(Callback):
"""Rate of change of metrics.
tracks and plots the rate of change of metrics effectively taking the
numerical derivative of the metrics
"""
def __init__(self, diff_train_metrics: bool=False, diff_eval_metrics: bool=True):
self.diff_train_metrics = diff_train_metrics
self.diff_eval_metrics = diff_eval_metrics
self.train_prev_loss = None
self.train_prev_metric = {}
self.eval_prev_metric = {}
def batch_end(self, state: State, logger: Logger) -> None:
if self.diff_train_metrics:
if not isinstance(state.loss, torch.Tensor):
raise NotImplementedError('Multiple losses not supported yet')
loss = state.loss.item()
if self.train_prev_loss:
logger.log_metrics({'loss/train/total_fdiff': loss - self.train_prev_loss})
self.train_prev_loss = loss
for k in self.train_prev_metric.keys():
logger.log_metrics({f'metrics/train/{k}_fdiff': state.train_metric_values[k] - self.train_prev_metric[k]})
for k in state.train_metric_values.keys():
value = state.train_metric_values[k]
self.train_prev_metric[k] = value
def eval_end(self, state: State, logger: Logger) -> None:
if self.diff_eval_metrics:
evaluator = state.dataloader_label
assert evaluator is not None, 'dataloader should have been set'
metrics = list(state.eval_metrics[evaluator].keys())
for k in metrics:
mkey = '/'.join(['metrics', evaluator, k])
if mkey in self.eval_prev_metric.keys():
logger.log_metrics({f'{mkey}_fdiff': state.eval_metric_values[k] - self.eval_prev_metric[mkey]})
for k in metrics:
mkey = '/'.join(['metrics', evaluator, k])
self.eval_prev_metric[mkey] = state.eval_metric_values[k]