import logging
import sys
import warnings
from typing import List, Sequence

import pytorch_lightning as pl
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only


# Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
class LoggingContext:
    """Context manager that temporarily changes a logger's level and/or handler.

    Args:
        logger: Logger to modify for the duration of the ``with`` block.
        level: Level to set on entry; the previous level is restored on exit.
            ``None`` leaves the level untouched.
        handler: Handler to attach on entry and detach on exit. A falsy value
            attaches nothing.
        close: Whether to close ``handler`` on exit.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger = logger
        self.level = level
        self.handler = handler
        self.close = close

    def __enter__(self):
        """Apply the requested level and handler to the logger."""
        if self.level is not None:
            # Remember the current level so __exit__ can restore it.
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        """Undo the changes made in ``__enter__``; exceptions propagate."""
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if self.handler:
            self.logger.removeHandler(self.handler)
        if self.handler and self.close:
            self.handler.close()
        # implicit return of None => don't swallow exceptions


def get_logger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Args:
        name: Name passed to ``logging.getLogger``.

    Returns:
        logging.Logger: The logger whose level methods are wrapped with
        ``rank_zero_only``.
    """
    logger = logging.getLogger(name)

    # ``logging.getLogger`` returns the same object for the same name, so
    # without this guard every call would stack one more ``rank_zero_only``
    # wrapper on top of the already wrapped methods.
    if getattr(logger, "_rank_zero_wrapped", False):
        return logger

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    ):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))
    logger._rank_zero_wrapped = True

    return logger


def extras(config: DictConfig) -> None:
    """A couple of optional utilities, controlled by main config file:
    - disabling warnings
    - forcing debug friendly configuration
    - verifying experiment name is set when running in experiment mode

    Modifies DictConfig in place.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Raises:
        SystemExit: If ``experiment_mode`` is set but ``name`` is not.
    """
    log = get_logger(__name__)

    # disable python warnings if `ignore_warnings` is set in the config
    if config.get("ignore_warnings"):
        log.info("Disabling python warnings! ")
        warnings.filterwarnings("ignore")

    # verify experiment name is set when running in experiment mode
    if config.get("experiment_mode") and not config.get("name"):
        log.info(
            "Running in experiment mode without the experiment name specified! "
            "Use `python run.py mode=exp name=experiment_name`"
        )
        log.info("Exiting...")
        # ``sys.exit`` instead of the builtin ``exit``: the latter is injected
        # by the ``site`` module and is not available in every interpreter.
        sys.exit()

    # force debugger friendly configuration if `trainer.fast_dev_run` is set
    # debuggers don't like GPUs and multiprocessing
    if config.trainer.get("fast_dev_run"):
        log.info("Forcing debugger friendly configuration! ")
        if config.trainer.get("gpus"):
            config.trainer.gpus = 0
        if config.datamodule.get("pin_memory"):
            config.datamodule.pin_memory = False
        if config.datamodule.get("num_workers"):
            config.datamodule.num_workers = 0


@rank_zero_only
def print_config(
    config: DictConfig,
    fields: Sequence[str] = (
        "trainer",
        "model",
        "datamodule",
        "train",
        "eval",
        "callbacks",
        "logger",
        "seed",
        "name",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    The same tree is also written to ``config_tree.txt`` in the current
    working directory.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        fields (Sequence[str], optional): Determines which main fields from config will
            be printed and in what order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        # Non-DictConfig values (scalars, missing fields) are shown via str().
        branch_content = str(config_section)
        if isinstance(config_section, DictConfig):
            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # Explicit UTF-8 so non-ASCII config values do not fail on platforms whose
    # default locale encoding cannot represent them.
    with open("config_tree.txt", "w", encoding="utf-8") as fp:
        rich.print(tree, file=fp)


def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Makes sure everything closed properly.

    Only ``logger`` is inspected; the remaining parameters are accepted but
    not used by this function.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        model (pl.LightningModule): Model instance.
        datamodule (pl.LightningDataModule): Datamodule instance.
        trainer (pl.Trainer): Trainer instance.
        callbacks (List[pl.Callback]): Callbacks.
        logger (List[pl.loggers.LightningLoggerBase]): Loggers; ``wandb.finish()``
            is called for each ``WandbLogger`` found.
    """
    # without this sweeps with wandb logger might crash!
    for lg in logger:
        if isinstance(lg, pl.loggers.wandb.WandbLogger):
            # Imported lazily so wandb is only required when it is in use.
            import wandb

            wandb.finish()