Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
import time | |
import warnings | |
from itertools import cycle | |
from typing import List, Optional | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import logging | |
from matplotlib import colors as mcolors | |
from visdom import Visdom | |
class AverageMeter(object):
    """
    Computes and stores the average and current value.
    Tracks the exact history of the added values in every epoch.

    `val`, `avg`, `sum` and `count` describe the running statistics since the
    last `reset()`. `history` is never cleared by `reset()`: `history[e]` is
    the list of values passed to `update` with `epoch=e`.
    """

    def __init__(self):
        """
        Initialize the structure with empty history and zero-ed moving average.
        """
        self.history = []
        self.reset()

    def reset(self):
        """
        Reset the running average meter.

        Only the running statistics are cleared; `history` is kept so that
        `get_epoch_averages` keeps covering all previous epochs.
        """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val: float, n: int = 1, epoch: int = 0):
        """
        Updates the average meter with a value `val`.

        Args:
            val: A float to be added to the meter. It is the average value of
                the `n` entities being added, i.e. it contributes `val * n`
                to the running sum.
            n: Represents the number of entities to be added.
            epoch: The epoch to which the number should be added.
        """
        # Make sure the history has an entry for every epoch up to `epoch`;
        # epochs without updates keep an empty list (reported as NaN).
        while len(self.history) <= epoch:
            self.history.append([])
        # Record the value exactly as given. `val` is already a per-entity
        # average (see the `val * n` accumulation below), so dividing it by
        # `n` here would make the epoch averages disagree with `avg`.
        self.history[epoch].append(val)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def get_epoch_averages(self):
        """
        Returns:
            averages: A list with, for each epoch in the history buffer, the
                (unweighted) mean of the values recorded in that epoch, or NaN
                for an epoch without values. `None` if the history is empty.
        """
        if not self.history:
            return None
        return [
            (float(np.array(h).mean()) if len(h) > 0 else float("NaN"))
            for h in self.history
        ]
class Stats(object):
    """
    Stats logging object useful for gathering statistics of training
    a deep network in PyTorch.

    Example:
        ```
        # Init stats structure that logs statistics 'objective' and 'top1e'.
        stats = Stats( ('objective','top1e') )

        network = init_net()  # init a pytorch module (=neural network)
        dataloader = init_dataloader()  # init a dataloader

        for epoch in range(10):

            # start of epoch -> call new_epoch
            stats.new_epoch()

            # Iterate over batches.
            for batch in dataloader:
                # Run a model and save into a dict of output variables "output"
                output = network(batch)

                # stats.update() automatically parses the 'objective' and 'top1e'
                # from the "output" dict and stores this into the db.
                stats.update(output)
                stats.print() # prints the averages over given epoch

            # Stores the training plots into '/tmp/epoch_stats.pdf'
            # and plots into a visdom server running at localhost (if running).
            stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
        ```
    """

    def __init__(
        self,
        log_vars: List[str],
        verbose: bool = False,
        epoch: int = -1,
        plot_file: Optional[str] = None,
    ):
        """
        Args:
            log_vars: The list of variable names to be logged.
            verbose: Print status messages.
            epoch: The initial epoch of the object.
            plot_file: The path to the file that will hold the training plots.
        """
        self.verbose = verbose
        self.log_vars = log_vars
        self.plot_file = plot_file
        self.hard_reset(epoch=epoch)

    def reset(self):
        """
        Called before an epoch to clear current epoch buffers.

        Resets the running statistics of every meter and the per-stat-set
        iteration counters, and restarts the epoch timer used for "sec/it".
        """
        stat_sets = list(self.stats.keys())
        if self.verbose:
            print("stats: epoch %d - reset" % self.epoch)
        self.it = {k: -1 for k in stat_sets}
        for stat_set in stat_sets:
            for stat in self.stats[stat_set]:
                self.stats[stat_set][stat].reset()

        # Set a new timestamp.
        self._epoch_start = time.time()

    def hard_reset(self, epoch: int = -1):
        """
        Erases all logged data.

        Args:
            epoch: The epoch counter to restart from (-1 = uninitialized).
        """
        self._epoch_start = None
        self.epoch = epoch
        if self.verbose:
            print("stats: epoch %d - hard reset" % self.epoch)
        self.stats = {}
        self.reset()

    def new_epoch(self):
        """
        Initializes a new epoch: increments the epoch counter and resets
        the current-epoch buffers.
        """
        if self.verbose:
            print("stats: new epoch %d" % (self.epoch + 1))
        self.epoch += 1  # increase epoch counter
        self.reset()  # zero the stats

    def _gather_value(self, val):
        """
        Convert a logged value to a Python float.

        Floats are returned unchanged. Objects exposing `.cpu()` (e.g. torch
        tensors) are detached through `.data` and copied to host memory;
        any other numeric value (ints, numpy scalars or arrays) is converted
        with `np.asarray`. Multi-element values are summed.
        """
        if isinstance(val, float):
            return val
        if hasattr(val, "cpu"):
            val = val.data.cpu().numpy()
        return float(np.asarray(val).sum())

    def update(self, preds: dict, stat_set: str = "train"):
        """
        Update the internal logs with metrics of a training step.

        Each metric is stored as an instance of an AverageMeter.

        Args:
            preds: Dict of values to be added to the logs.
            stat_set: The set of statistics to be updated (e.g. "train", "val").
        """
        if self.epoch == -1:  # uninitialized
            warnings.warn(
                "self.epoch==-1 means uninitialized stats structure"
                " -> new_epoch() called"
            )
            self.new_epoch()

        if stat_set not in self.stats:
            self.stats[stat_set] = {}
            self.it[stat_set] = -1

        self.it[stat_set] += 1

        epoch = self.epoch
        it = self.it[stat_set]

        for stat in self.log_vars:

            if stat not in self.stats[stat_set]:
                self.stats[stat_set][stat] = AverageMeter()

            if stat == "sec/it":  # compute speed
                elapsed = time.time() - self._epoch_start
                val = float(elapsed) / float(it + 1)
            elif stat in preds:
                val = self._gather_value(preds[stat])
            else:
                val = None

            # Missing or NaN values are not recorded.
            if val is not None and not np.isnan(val):
                self.stats[stat_set][stat].update(val, epoch=epoch, n=1)

    def print(self, max_it: Optional[int] = None, stat_set: str = "train"):
        """
        Print the current values of all stored stats.

        The line is emitted through `logging.info`.

        Args:
            max_it: Maximum iteration number to be displayed.
                If None, the maximum iteration number is not displayed.
            stat_set: The set of statistics to be printed.
        """
        epoch = self.epoch
        stats = self.stats
        it = self.it[stat_set]

        stat_str = ""
        for stat in sorted(stats[stat_set].keys()):
            if stats[stat_set][stat].count == 0:
                continue
            stat_str += " {0:.12}: {1:1.3f} |".format(stat, stats[stat_set][stat].avg)

        head_str = f"[{stat_set}] | epoch {epoch} | it {it}"
        if max_it:
            head_str += f"/ {max_it}"

        str_out = f"{head_str} | {stat_str}"

        logging.info(str_out)

    def plot_stats(
        self,
        viz: Optional["Visdom"] = None,
        visdom_env: Optional[str] = None,
        plot_file: Optional[str] = None,
    ):
        """
        Plot the line charts of the history of the stats.

        Args:
            viz: The Visdom object holding the connection to a Visdom server.
            visdom_env: The visdom environment for storing the graphs.
            plot_file: The path to a file with training plots. Defaults to
                the `plot_file` given to the constructor; nothing is written
                if both are None.
        """
        if viz is None:
            withvisdom = False
        elif not viz.check_connection():
            warnings.warn("Cannot connect to the visdom server! Skipping visdom plots.")
            withvisdom = False
        else:
            withvisdom = True

        lines = self._collect_plot_lines()

        if withvisdom:
            self._plot_visdom(viz, visdom_env, lines)

        if plot_file is None:
            plot_file = self.plot_file

        if plot_file is not None:
            self._plot_matplotlib(plot_file, lines)

    def _collect_plot_lines(self):
        """
        Gather the per-epoch averages of every logged stat.

        Returns:
            A list of `(stat_sets, stat, x, vals)` tuples, one per stat in
            `self.log_vars` with any history. `vals` has one column per entry
            of `stat_sets` and one row per epoch; `x` holds the epoch indices.
        """
        lines = []
        for stat in self.log_vars:
            vals = []
            stat_sets_now = []
            for stat_set, meters in self.stats.items():
                meter = meters.get(stat)
                if meter is None:
                    continue
                val = meter.get_epoch_averages()
                if val is None:
                    continue
                stat_sets_now.append(stat_set)
                vals.append(np.array(val).reshape(-1))

            if len(vals) == 0:
                continue

            vals = self._stack_padded(vals)
            x = np.arange(vals.shape[0])
            lines.append((stat_sets_now, stat, x, vals))
        return lines

    @staticmethod
    def _stack_padded(columns):
        """
        Stack 1D arrays as the columns of a 2D float array.

        Shorter arrays are padded with NaN at the end, so stat sets logged
        for a different number of epochs (e.g. "train" one epoch ahead of
        "val") can share one plot instead of making `np.stack` raise.
        """
        length = max(len(col) for col in columns)
        out = np.full((length, len(columns)), np.nan, dtype=float)
        for j, col in enumerate(columns):
            out[: len(col), j] = col
        return out

    @staticmethod
    def _plot_visdom(viz, visdom_env, lines):
        """Send one visdom line chart per stat, with one trace per stat set."""
        for tmodes, stat, x, vals in lines:
            title = "%s" % stat
            opts = {"title": title, "legend": list(tmodes)}
            # The first trace actually sent replaces the window content;
            # later traces are appended to it.
            plotted_any = False
            for tmode, val in zip(tmodes, vals.T):
                valid = np.isfinite(val)
                if not valid.any():
                    continue
                viz.line(
                    Y=val[valid],
                    X=x[valid],
                    env=visdom_env,
                    opts=opts,
                    win=f"stat_plot_{title}",
                    name=tmode,
                    update="append" if plotted_any else None,
                )
                plotted_any = True

    @staticmethod
    def _plot_matplotlib(plot_file, lines):
        """Draw all stats into a grid of subplots and save it to `plot_file`."""
        print("Exporting stats to %s" % plot_file)
        ncol = 3
        nrow = int(np.ceil(float(len(lines)) / ncol))
        matplotlib.rcParams.update({"font.size": 5})
        color = cycle(plt.cm.tab10(np.linspace(0, 1, 10)))
        fig = plt.figure(1)
        plt.clf()
        gcolor = np.array(mcolors.to_rgba("lightgray"))
        for idx, (tmodes, stat, x, vals) in enumerate(lines):
            c = next(color)
            plt.subplot(nrow, ncol, idx + 1)
            for vali, vals_ in enumerate(vals.T):
                # Darken later stat sets. The floor keeps the RGBA values in
                # [0, 1]; a negative factor (more than four stat sets) would
                # make matplotlib reject the colour.
                c_ = c * max(1.0 - float(vali) * 0.3, 0.1)
                valid = np.isfinite(vals_)
                # Plot even when nothing is valid so legend entries stay
                # aligned with `tmodes`.
                plt.plot(x[valid], vals_[valid], c=c_, linewidth=1)
            plt.ylabel(stat)
            plt.xlabel("epoch")
            plt.gca().yaxis.label.set_color(c[0:3] * 0.75)
            plt.legend(tmodes)
            # The on/off flag is passed positionally: its keyword was renamed
            # from `b` to `visible` in Matplotlib 3.5 and `b` was later removed.
            plt.grid(True, which="major", color=gcolor, linestyle="-", linewidth=0.4)
            plt.grid(True, which="minor", color=gcolor, linestyle="--", linewidth=0.2)
            plt.minorticks_on()

        plt.tight_layout()
        # Save before showing: with an interactive backend `show()` blocks and
        # the figure is closed afterwards, so saving later writes an empty plot.
        fig.savefig(plot_file)
        plt.show()