Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Wrapper around various loggers and progress bars (e.g., tqdm). | |
""" | |
import atexit | |
import json | |
import logging | |
import os | |
import sys | |
from collections import OrderedDict | |
from contextlib import contextmanager | |
from numbers import Number | |
from typing import Optional | |
import torch | |
from .meters import AverageMeter, StopwatchMeter, TimeMeter | |
logger = logging.getLogger(__name__) | |
def progress_bar(
    iterator,
    log_format: Optional[str] = None,
    log_interval: int = 100,
    log_file: Optional[str] = None,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    aim_repo: Optional[str] = None,
    aim_run_hash: Optional[str] = None,
    aim_param_checkpoint_dir: Optional[str] = None,
    tensorboard_logdir: Optional[str] = None,
    default_log_format: str = "tqdm",
    wandb_project: Optional[str] = None,
    wandb_run_name: Optional[str] = None,
    azureml_logging: Optional[bool] = False,
):
    """Wrap ``iterator`` in a progress bar plus any requested logging wrappers.

    Args:
        iterator: iterable handed to the selected progress bar class.
        log_format: one of "json", "none", "simple" or "tqdm"; falls back to
            ``default_log_format`` when None.
        log_interval: forwarded to the "json" and "simple" bars (and to the
            FB-only TensorBoard wrapper when it is importable).
        log_file: if given, a ``logging.FileHandler`` for this path is attached
            to this module's logger.
        epoch: forwarded to the progress bar.
        prefix: forwarded to the progress bar.
        aim_repo: if truthy, wrap the bar in ``AimProgressBarWrapper``.
        aim_run_hash: forwarded to ``AimProgressBarWrapper``.
        aim_param_checkpoint_dir: forwarded to ``AimProgressBarWrapper``.
        tensorboard_logdir: if truthy, wrap the bar in a TensorBoard wrapper.
        default_log_format: format used when ``log_format`` is None.
        wandb_project: if truthy, wrap the bar in ``WandBProgressBarWrapper``.
        wandb_run_name: forwarded to ``WandBProgressBarWrapper`` as ``run_name``.
        azureml_logging: if truthy, wrap the bar in ``AzureMLProgressBarWrapper``.

    Returns:
        The (possibly wrapped) progress bar.

    Raises:
        ValueError: if the resolved log format is not one of the known names.
    """
    if log_format is None:
        log_format = default_log_format

    if log_file is not None:
        # This function is typically called repeatedly; attaching a fresh
        # FileHandler each time would leak open files and duplicate every
        # log line, so only attach one handler per target file.
        log_path = os.path.abspath(log_file)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in logger.handlers
        )
        if not already_attached:
            logger.addHandler(logging.FileHandler(filename=log_file))

    # tqdm needs an interactive terminal; degrade to the plain logger otherwise.
    if log_format == "tqdm" and not sys.stderr.isatty():
        log_format = "simple"

    if log_format == "json":
        bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "none":
        bar = NoopProgressBar(iterator, epoch, prefix)
    elif log_format == "simple":
        bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "tqdm":
        bar = TqdmProgressBar(iterator, epoch, prefix)
    else:
        raise ValueError("Unknown log format: {}".format(log_format))

    if aim_repo:
        bar = AimProgressBarWrapper(
            bar,
            aim_repo=aim_repo,
            aim_run_hash=aim_run_hash,
            aim_param_checkpoint_dir=aim_param_checkpoint_dir,
        )

    if tensorboard_logdir:
        try:
            # [FB only] custom wrapper for TensorBoard
            import palaas  # noqa

            from .fb_tbmf_wrapper import FbTbmfWrapper

            bar = FbTbmfWrapper(bar, log_interval)
        except ImportError:
            bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)

    if wandb_project:
        bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name)

    if azureml_logging:
        bar = AzureMLProgressBarWrapper(bar)

    return bar
def build_progress_bar(
    args,
    iterator,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default: str = "tqdm",
    no_progress_bar: str = "none",
):
    """Legacy wrapper that takes an argparse.Namespace.

    Reads ``no_progress_bar``, ``distributed_rank``, ``tensorboard_logdir``,
    ``log_format`` and ``log_interval`` from ``args`` and forwards them to
    ``progress_bar``.
    """
    # When args.no_progress_bar is set, the fallback format becomes the
    # caller-supplied ``no_progress_bar`` format name.
    fallback_format = (
        no_progress_bar if getattr(args, "no_progress_bar", False) else default
    )

    # Only rank 0 gets a TensorBoard log directory.
    on_rank_zero = getattr(args, "distributed_rank", 0) == 0
    tb_dir = getattr(args, "tensorboard_logdir", None) if on_rank_zero else None

    return progress_bar(
        iterator,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch,
        prefix=prefix,
        tensorboard_logdir=tb_dir,
        default_log_format=fallback_format,
    )
def format_stat(stat):
    """Convert a single stat into a printable / serializable value.

    Numbers are rendered with ``{:g}``, ``AverageMeter`` by its ``avg`` with
    three decimals, ``TimeMeter`` by its rounded ``avg``, ``StopwatchMeter``
    by its rounded ``sum``, and tensors are converted with ``tolist()``.
    Anything else is returned unchanged.
    """
    if isinstance(stat, Number):
        return "{:g}".format(stat)
    if isinstance(stat, AverageMeter):
        return "{:.3f}".format(stat.avg)
    if isinstance(stat, TimeMeter):
        return "{:g}".format(round(stat.avg))
    if isinstance(stat, StopwatchMeter):
        return "{:g}".format(round(stat.sum))
    if torch.is_tensor(stat):
        return stat.tolist()
    return stat
class BaseProgressBar(object):
    """Abstract class for progress bars."""

    def __init__(self, iterable, epoch=None, prefix=None):
        self.iterable = iterable
        # Start offset: taken from the iterable's ``n`` attribute when present.
        self.n = getattr(iterable, "n", 0)
        self.epoch = epoch

        # The display prefix is "epoch NNN", the user prefix, or both joined
        # by " | ".
        pieces = []
        if epoch is not None:
            pieces.append("epoch {:03d}".format(epoch))
        if prefix is not None:
            pieces.append(prefix)
        self.prefix = " | ".join(pieces)

    def __len__(self):
        return len(self.iterable)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def __iter__(self):
        raise NotImplementedError

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def update_config(self, config):
        """Log latest configuration."""
        pass

    def _str_commas(self, stats):
        """Render already-formatted stats as ``key=value`` joined by ", "."""
        return ", ".join(name + "=" + text.strip() for name, text in stats.items())

    def _str_pipes(self, stats):
        """Render already-formatted stats as ``key value`` joined by " | "."""
        return " | ".join(name + " " + text.strip() for name, text in stats.items())

    def _format_stats(self, stats):
        """Return an OrderedDict mapping each stat name to its string form."""
        # Preprocess stats according to datatype
        return OrderedDict(
            (name, str(format_stat(value)))
            for name, value in OrderedDict(stats).items()
        )
@contextmanager
def rename_logger(logger, new_name):
    """Temporarily rename ``logger`` for the duration of a ``with`` block.

    Args:
        logger: object with a writable ``name`` attribute.
        new_name: name to use inside the block; when None the name is left
            unchanged.

    Yields:
        The same ``logger`` object.
    """
    old_name = logger.name
    if new_name is not None:
        logger.name = new_name
    try:
        yield logger
    finally:
        # Restore the original name even if the body raised.
        logger.name = old_name
class JsonProgressBar(BaseProgressBar):
    """Log output in JSON format."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        # Both are filled in once iteration starts.
        self.i = None
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for index, item in enumerate(self.iterable, start=self.n):
            self.i = index
            yield item

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0
        if (
            not step > 0
            or self.log_interval is None
            or step % self.log_interval != 0
        ):
            return

        # Fractional epoch progress, only meaningful when an epoch was given.
        if self.epoch is not None:
            update = self.epoch - 1 + (self.i + 1) / float(self.size)
        else:
            update = None

        payload = self._format_stats(stats, epoch=self.epoch, update=update)
        with rename_logger(logger, tag):
            logger.info(json.dumps(payload))

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self.stats = stats
        if tag is not None:
            # Prefix every key with "<tag>_".
            self.stats = OrderedDict(
                (tag + "_" + name, value) for name, value in self.stats.items()
            )
        payload = self._format_stats(self.stats, epoch=self.epoch)
        with rename_logger(logger, tag):
            logger.info(json.dumps(payload))

    def _format_stats(self, stats, epoch=None, update=None):
        """Build the ordered JSON payload: epoch, update, then each stat."""
        payload = OrderedDict()
        if epoch is not None:
            payload["epoch"] = epoch
        if update is not None:
            payload["update"] = round(update, 3)
        # Preprocess stats according to datatype
        for name in stats.keys():
            payload[name] = format_stat(stats[name])
        return payload
class NoopProgressBar(BaseProgressBar):
    """No logging."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)

    def __iter__(self):
        # Pass items through untouched.
        yield from self.iterable

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        return None

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        return None
class SimpleProgressBar(BaseProgressBar):
    """A minimal logger for non-TTY environments."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        # Both are filled in once iteration starts.
        self.i = None
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for index, item in enumerate(self.iterable, start=self.n):
            self.i = index
            yield item

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0
        if (
            not step > 0
            or self.log_interval is None
            or step % self.log_interval != 0
        ):
            return

        formatted = self._format_stats(stats)
        postfix = self._str_commas(formatted)
        with rename_logger(logger, tag):
            message = "{}: {:5d} / {:d} {}".format(
                self.prefix, self.i + 1, self.size, postfix
            )
            logger.info(message)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        formatted = self._format_stats(stats)
        postfix = self._str_pipes(formatted)
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, postfix))
class TqdmProgressBar(BaseProgressBar):
    """Log to tqdm."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
        # Imported lazily so tqdm is only required when this bar is used.
        from tqdm import tqdm

        # Hide the bar when the logger would not emit INFO messages.
        bar_disabled = logger.getEffectiveLevel() > logging.INFO
        self.tqdm = tqdm(
            iterable,
            self.prefix,
            leave=False,
            disable=bar_disabled,
        )

    def __iter__(self):
        return iter(self.tqdm)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        formatted = self._format_stats(stats)
        self.tqdm.set_postfix(formatted, refresh=False)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        formatted = self._format_stats(stats)
        postfix = self._str_pipes(formatted)
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, postfix))
# Optional Aim integration: ``get_aim_run`` and ``AimRepo`` are None when the
# ``aim`` package cannot be imported.
try:
    import functools

    from aim import Repo as AimRepo

    @functools.lru_cache()
    def get_aim_run(repo, run_hash):
        """Return the Aim ``Run`` for ``(repo, run_hash)``.

        The result is memoized so repeated calls with the same arguments
        reuse one ``Run`` object instead of opening a new one each time.
        """
        from aim import Run

        return Run(run_hash=run_hash, repo=repo)

except ImportError:
    get_aim_run = None
    AimRepo = None
class AimProgressBarWrapper(BaseProgressBar):
    """Log to Aim."""

    def __init__(self, wrapped_bar, aim_repo, aim_run_hash, aim_param_checkpoint_dir):
        """Wrap ``wrapped_bar`` and open (or look up) an Aim run.

        Args:
            wrapped_bar: progress bar that all calls are forwarded to.
            aim_repo: Aim repository passed to ``AimRepo`` / ``get_aim_run``.
            aim_run_hash: hash of the run to use; when falsy, a run whose
                ``run.checkpoint.save_dir`` equals ``aim_param_checkpoint_dir``
                is searched for in the repo.
            aim_param_checkpoint_dir: checkpoint directory used in that query.
        """
        self.wrapped_bar = wrapped_bar

        if get_aim_run is None:
            self.run = None
            logger.warning("Aim not found, please install with: pip install aim")
        else:
            logger.info(f"Storing logs at Aim repo: {aim_repo}")

            if not aim_run_hash:
                # Find run based on save_dir parameter
                query = f"run.checkpoint.save_dir == '{aim_param_checkpoint_dir}'"
                try:
                    runs_generator = AimRepo(aim_repo).query_runs(query)
                    run = next(runs_generator.iter_runs())
                    aim_run_hash = run.run.hash
                except Exception:
                    # Best effort: if the lookup fails or finds nothing, fall
                    # through with no hash.
                    pass

            if aim_run_hash:
                logger.info(f"Appending to run: {aim_run_hash}")

            self.run = get_aim_run(aim_repo, aim_run_hash)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to Aim."""
        self._log_to_aim(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_aim(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if self.run is not None:
            for key in config:
                self.run.set(key, config[key], strict=False)
        self.wrapped_bar.update_config(config)

    def _log_to_aim(self, stats, tag=None, step=None):
        """Track every stat except ``num_updates`` on the Aim run.

        ``step`` defaults to ``stats["num_updates"]``. The tracking context
        carries the tag, plus a ``subset`` of "train" or "val" when the tag
        contains that substring.
        """
        if self.run is None:
            return

        if step is None:
            step = stats["num_updates"]

        if tag is None:
            # ``tag`` defaults to None on log()/print(); a substring test on
            # None would raise TypeError, so track with an empty context.
            context = {}
        elif "train" in tag:
            context = {"tag": tag, "subset": "train"}
        elif "val" in tag:
            context = {"tag": tag, "subset": "val"}
        else:
            context = {"tag": tag}

        for key in stats.keys() - {"num_updates"}:
            self.run.track(stats[key], name=key, step=step, context=context)
# Cache of SummaryWriter objects, keyed by the sub-directory they write to.
_tensorboard_writers = {}

# Prefer the writer bundled with torch, then tensorboardX; SummaryWriter is
# None when neither can be imported.
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    try:
        from tensorboardX import SummaryWriter
    except ImportError:
        SummaryWriter = None


def _close_writers():
    """Close every cached TensorBoard writer."""
    for writer in _tensorboard_writers.values():
        writer.close()


# Close the cached writers at interpreter exit.
atexit.register(_close_writers)
class TensorboardProgressBarWrapper(BaseProgressBar):
    """Log to tensorboard."""

    def __init__(self, wrapped_bar, tensorboard_logdir):
        self.wrapped_bar = wrapped_bar
        self.tensorboard_logdir = tensorboard_logdir

        if SummaryWriter is None:
            logger.warning(
                "tensorboard not found, please install with: pip install tensorboard"
            )

    def _writer(self, key):
        """Return the cached writer for ``key``, creating it on first use.

        Returns None when no SummaryWriter implementation is available.
        """
        if SummaryWriter is None:
            return None
        if key in _tensorboard_writers:
            return _tensorboard_writers[key]

        new_writer = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
        _tensorboard_writers[key] = new_writer
        # Record the command line once per newly created writer.
        new_writer.add_text("sys.argv", " ".join(sys.argv))
        return new_writer

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to tensorboard."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        # TODO add hparams to Tensorboard
        self.wrapped_bar.update_config(config)

    def _log_to_tensorboard(self, stats, tag=None, step=None):
        """Write scalar stats (all keys except ``num_updates``) to the writer.

        ``step`` defaults to ``stats["num_updates"]``. AverageMeter values are
        logged by ``val``, numbers as-is, and single-element tensors by
        ``item()``; other values are skipped.
        """
        writer = self._writer(tag or "")
        if writer is None:
            return

        if step is None:
            step = stats["num_updates"]

        for key in stats.keys() - {"num_updates"}:
            value = stats[key]
            if isinstance(value, AverageMeter):
                scalar = value.val
            elif isinstance(value, Number):
                scalar = value
            elif torch.is_tensor(value) and value.numel() == 1:
                scalar = value.item()
            else:
                continue
            writer.add_scalar(key, scalar, step)
        writer.flush()
try: | |
import wandb | |
except ImportError: | |
wandb = None | |
class WandBProgressBarWrapper(BaseProgressBar):
    """Log to Weights & Biases."""

    def __init__(self, wrapped_bar, wandb_project, run_name=None):
        self.wrapped_bar = wrapped_bar

        if wandb is not None:
            # reinit=False to ensure if wandb.init() is called multiple times
            # within one process it still references the same run
            wandb.init(project=wandb_project, reinit=False, name=run_name)
        else:
            logger.warning("wandb not found, pip install wandb")

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to Weights & Biases."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if wandb is not None:
            wandb.config.update(config)
        self.wrapped_bar.update_config(config)

    def _log_to_wandb(self, stats, tag=None, step=None):
        """Send numeric stats (all keys except ``num_updates``) to wandb.

        ``step`` defaults to ``stats["num_updates"]``; each metric name is
        prefixed with "<tag>/" when a tag is given.
        """
        if wandb is None:
            return

        if step is None:
            step = stats["num_updates"]

        prefix = tag + "/" if tag is not None else ""

        for key in stats.keys() - {"num_updates"}:
            value = stats[key]
            if isinstance(value, AverageMeter):
                wandb.log({prefix + key: value.val}, step=step)
            elif isinstance(value, Number):
                wandb.log({prefix + key: value}, step=step)
try: | |
from azureml.core import Run | |
except ImportError: | |
Run = None | |
class AzureMLProgressBarWrapper(BaseProgressBar):
    """Log to Azure ML"""

    def __init__(self, wrapped_bar):
        self.wrapped_bar = wrapped_bar

        if Run is not None:
            self.run = Run.get_context()
        else:
            # ``self.run`` is left unset; every other method checks
            # ``Run is None`` before touching it.
            logger.warning("azureml.core not found, pip install azureml-core")

    def __exit__(self, *exc):
        # Mark the Azure ML run as complete when leaving the ``with`` block.
        if Run is not None:
            self.run.complete()
        return False

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to AzureML"""
        self._log_to_azureml(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats"""
        self._log_to_azureml(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        self.wrapped_bar.update_config(config)

    def _log_to_azureml(self, stats, tag=None, step=None):
        """Log numeric stats (all keys except ``num_updates``) as rows.

        ``step`` defaults to ``stats["num_updates"]``; each row is named
        "<tag>/<key>" (or just "<key>" without a tag) and holds the step
        and the value.
        """
        if Run is None:
            return

        if step is None:
            step = stats["num_updates"]

        prefix = tag + "/" if tag is not None else ""

        for key in stats.keys() - {"num_updates"}:
            value = stats[key]
            name = prefix + key
            if isinstance(value, AverageMeter):
                self.run.log_row(name=name, **{"step": step, key: value.val})
            elif isinstance(value, Number):
                self.run.log_row(name=name, **{"step": step, key: value})