Spaces:
Sleeping
Sleeping
import logging
import sys
import warnings
from typing import List, Sequence

import pytorch_lightning as pl
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
# Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
class LoggingContext:
    """Context manager that temporarily reconfigures a logger.

    While the ``with`` block is active the logger can run at a different
    level and/or have an extra handler attached. On exit the previous level
    is restored and the handler is detached (and closed when ``close`` is
    true).

    Args:
        logger: Logger to reconfigure for the duration of the block.
        level: Level to apply inside the block; ``None`` leaves it untouched.
        handler: Handler to attach inside the block; a falsy value attaches
            nothing.
        close: Whether the attached handler is closed on exit.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger, self.level = logger, level
        self.handler, self.close = handler, close

    def __enter__(self):
        target = self.logger
        if self.level is not None:
            # remember the current level so __exit__ can put it back
            self.old_level = target.level
            target.setLevel(self.level)
        if self.handler:
            target.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        attached = self.handler
        if not attached:
            # nothing was attached, so nothing to detach or close
            return
        self.logger.removeHandler(attached)
        if self.close:
            attached.close()
        # falling off the end returns None => exceptions are not swallowed
def get_logger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Each level method of the logger is wrapped with ``rank_zero_only`` so
    that logs are not multiplied for each GPU process in a multi-GPU setup.

    ``logging.getLogger`` returns the same logger object for the same name,
    so the wrapping is applied only the first time a given logger is
    requested; later calls return it unchanged instead of stacking another
    layer of wrappers on every method.

    Args:
        name: Name of the logger to fetch; defaults to this module's name.

    Returns:
        logging.Logger: The logger with its level methods wrapped.
    """
    logger = logging.getLogger(name)

    # already patched by an earlier call -> don't wrap the wrappers again
    if getattr(logger, "_rank_zero_patched", False):
        return logger

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    logger._rank_zero_patched = True
    return logger
def extras(config: DictConfig) -> None:
    """A couple of optional utilities, controlled by main config file:

    - disabling warnings
    - forcing debug friendly configuration
    - verifying experiment name is set when running in experiment mode

    Modifies DictConfig in place.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Raises:
        SystemExit: If ``experiment_mode`` is set but ``name`` is not.
    """
    log = get_logger(__name__)

    # disable python warnings if <config.ignore_warnings=True>
    if config.get("ignore_warnings"):
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # verify experiment name is set when running in experiment mode
    if config.get("experiment_mode") and not config.get("name"):
        log.info(
            "Running in experiment mode without the experiment name specified! "
            "Use `python run.py mode=exp name=experiment_name`"
        )
        log.info("Exiting...")
        # the bare `exit()` builtin is installed by the `site` module for
        # interactive sessions; sys.exit() raises the same SystemExit (same
        # exit status) without depending on it
        sys.exit()

    # force debugger friendly configuration if <config.trainer.fast_dev_run=True>
    # debuggers don't like GPUs and multiprocessing
    # sections are read with .get() so a config without `trainer` or
    # `datamodule` is skipped instead of failing on the attribute access
    trainer_cfg = config.get("trainer")
    if trainer_cfg is not None and trainer_cfg.get("fast_dev_run"):
        log.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
        if trainer_cfg.get("gpus"):
            trainer_cfg.gpus = 0

        datamodule_cfg = config.get("datamodule")
        if datamodule_cfg is not None:
            if datamodule_cfg.get("pin_memory"):
                datamodule_cfg.pin_memory = False
            if datamodule_cfg.get("num_workers"):
                datamodule_cfg.num_workers = 0
def print_config(
    config: DictConfig,
    fields: Sequence[str] = (
        "trainer",
        "model",
        "datamodule",
        "train",
        "eval",
        "callbacks",
        "logger",
        "seed",
        "name",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    The tree is printed to standard output and also written to
    ``config_tree.txt`` in the current working directory.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        fields (Sequence[str], optional): Determines which main fields from config will
            be printed and in what order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        if isinstance(config_section, DictConfig):
            # nested sections are rendered as YAML
            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
        else:
            # scalars (and missing fields, which come back as None) are shown via str()
            branch_content = str(config_section)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # explicit encoding: without it the file is written with the platform's
    # default encoding, which varies between machines
    with open("config_tree.txt", "w", encoding="utf-8") as fp:
        rich.print(tree, file=fp)
def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Makes sure everything closed properly.

    Only ``logger`` is inspected: ``wandb.finish()`` is called once for each
    entry that is a ``WandbLogger``. The remaining arguments are accepted
    but not used.
    """
    # without this sweeps with wandb logger might crash!
    for candidate in logger:
        if not isinstance(candidate, pl.loggers.wandb.WandbLogger):
            continue

        # imported lazily so wandb is only needed when a wandb logger is in use
        import wandb

        wandb.finish()