import logging
import warnings
from typing import List, Sequence
import pytorch_lightning as pl
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
# Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
class LoggingContext:
    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        if self.level is not None:
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
        if self.handler and self.close:
            self.handler.close()
        # implicit return of None => don't swallow exceptions
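

# Illustrative sketch (not part of the original template): temporarily raise a logger to
# DEBUG and attach an extra handler for one noisy code path. On exit the previous level is
# restored and the handler is removed and closed. The logger name below is hypothetical.
def _example_logging_context() -> None:
    import sys

    demo_logger = logging.getLogger("demo")
    debug_handler = logging.StreamHandler(sys.stderr)
    with LoggingContext(demo_logger, level=logging.DEBUG, handler=debug_handler):
        demo_logger.debug("visible only inside this block")
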
def get_logger(name=__name__) -> logging.Logger:
    """Initializes a multi-GPU-friendly Python logger."""

    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger
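

# Illustrative sketch (not part of the original template): the returned object behaves like
# a standard logging.Logger, but every logging method is wrapped with `rank_zero_only`, so
# each message is emitted once per job even when several GPU processes run this code.
def _example_get_logger() -> None:
    logging.basicConfig(level=logging.INFO)
    log = get_logger("demo")  # hypothetical logger name
    log.info("emitted only by the rank-zero process")
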
def extras(config: DictConfig) -> None:
    """A couple of optional utilities, controlled by the main config file:
    - disabling Python warnings
    - forcing a debugger-friendly configuration
    - verifying that the experiment name is set when running in experiment mode

    Modifies DictConfig in place.

    Args:
        config (DictConfig): Configuration composed by Hydra.
    """

    log = get_logger(__name__)

    # disable python warnings if <config.ignore_warnings=True>
    if config.get("ignore_warnings"):
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # verify experiment name is set when running in experiment mode
    if config.get("experiment_mode") and not config.get("name"):
        log.info(
            "Running in experiment mode without the experiment name specified! "
            "Use `python run.py mode=exp name=experiment_name`"
        )
        log.info("Exiting...")
        exit()

    # force debugger friendly configuration if <config.trainer.fast_dev_run=True>
    # debuggers don't like GPUs and multiprocessing
    if config.trainer.get("fast_dev_run"):
        log.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
        if config.trainer.get("gpus"):
            config.trainer.gpus = 0
        if config.datamodule.get("pin_memory"):
            config.datamodule.pin_memory = False
        if config.datamodule.get("num_workers"):
            config.datamodule.num_workers = 0
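

# Illustrative sketch (not part of the original template): exercising `extras` with a
# hand-built config. The keys mirror what the function reads; in the real pipeline the
# config is composed by Hydra rather than created manually.
def _example_extras() -> None:
    cfg = OmegaConf.create(
        {
            "ignore_warnings": False,
            "trainer": {"fast_dev_run": True, "gpus": 1},
            "datamodule": {"pin_memory": True, "num_workers": 4},
        }
    )
    extras(cfg)
    # fast_dev_run forces CPU-only, single-process data loading
    assert cfg.trainer.gpus == 0
    assert not cfg.datamodule.pin_memory and cfg.datamodule.num_workers == 0
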
@rank_zero_only
def print_config(
    config: DictConfig,
    fields: Sequence[str] = (
        "trainer",
        "model",
        "datamodule",
        "train",
        "eval",
        "callbacks",
        "logger",
        "seed",
        "name",
    ),
    resolve: bool = True,
) -> None:
    """Prints the content of a DictConfig as a tree structure using the Rich library.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        fields (Sequence[str], optional): Determines which main fields from config will
            be printed and in what order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """

    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        branch_content = str(config_section)
        if isinstance(config_section, DictConfig):
            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    with open("config_tree.txt", "w") as fp:
        rich.print(tree, file=fp)
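

# Illustrative sketch (not part of the original template): printing a minimal config tree.
# Fields missing from the config are rendered as `None` branches, and the tree is also
# written to `config_tree.txt` in the current working directory.
def _example_print_config() -> None:
    cfg = OmegaConf.create({"trainer": {"max_epochs": 3}, "seed": 42, "name": "demo"})
    print_config(cfg, fields=("trainer", "seed", "name"), resolve=True)
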
def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Makes sure everything is closed properly."""

    # without this, sweeps with the wandb logger might crash!
    for lg in logger:
        if isinstance(lg, pl.loggers.wandb.WandbLogger):
            import wandb

            wandb.finish()