Spaces:
Sleeping
Sleeping
File size: 4,957 Bytes
e45d058 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import logging
import sys
import warnings
from typing import List, Sequence

import pytorch_lightning as pl
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
# Adapted from the Python logging cookbook:
# https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging
class LoggingContext:
    """Context manager that temporarily changes a logger's level and/or attaches a handler.

    Args:
        logger: Logger to modify while the ``with`` block runs.
        level: If not None, set as the logger's level on entry; the previous level is
            restored on exit.
        handler: If truthy, added to the logger on entry and removed on exit.
        close: If truthy (default), ``handler.close()`` is also called on exit.

    Note:
        ``__enter__`` returns None, so ``with LoggingContext(...) as ctx`` binds ``ctx`` to None.
    """

    def __init__(self, logger, level=None, handler=None, close=True):
        self.logger, self.level = logger, level
        self.handler, self.close = handler, close

    def __enter__(self):
        if self.level is not None:
            # remember the current level so __exit__ can restore it
            self.old_level = self.logger.level
            self.logger.setLevel(self.level)
        if self.handler:
            self.logger.addHandler(self.handler)

    def __exit__(self, et, ev, tb):
        if self.level is not None:
            self.logger.setLevel(self.old_level)
        if not self.handler:
            return None
        self.logger.removeHandler(self.handler)
        if self.close:
            self.handler.close()
        # falsy return value => an exception raised inside the block is not swallowed
        return None
def get_logger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    The logger's level methods (debug, info, warning, error, exception, fatal,
    critical) are replaced by ``rank_zero_only``-wrapped versions. Otherwise, in a
    multi-GPU (multi-process) setup every message would be logged once per process.

    ``logging.getLogger`` returns the same object for the same name. The wrapping is
    therefore applied only once per logger, and later calls return the already-wrapped
    logger. Without this guard, every call (e.g. each invocation of ``extras``) would
    stack another wrapper layer, making every log call progressively deeper.

    Args:
        name: Logger name passed to ``logging.getLogger``.

    Returns:
        The shared ``logging.Logger`` instance for ``name``.
    """
    logger = logging.getLogger(name)
    if getattr(logger, "_rank_zero_only_wrapped", False):
        return logger

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))
    logger._rank_zero_only_wrapped = True
    return logger
def extras(config: DictConfig) -> None:
    """A couple of optional utilities, controlled by main config file:

    - disabling python warnings if ``config.ignore_warnings`` is truthy
    - verifying the experiment name is set when running in experiment mode
    - forcing a debugger-friendly configuration if ``config.trainer.fast_dev_run`` is
      truthy (GPUs set to 0, pin_memory disabled, num_workers set to 0)

    Modifies DictConfig in place. Missing ``trainer`` / ``datamodule`` sections are
    tolerated.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Raises:
        SystemExit: With exit status 1 when ``config.experiment_mode`` is truthy but
            ``config.name`` is not set.
    """
    log = get_logger(__name__)

    # disable python warnings if <config.ignore_warnings=True>
    if config.get("ignore_warnings"):
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # verify experiment name is set when running in experiment mode
    if config.get("experiment_mode") and not config.get("name"):
        log.info(
            "Running in experiment mode without the experiment name specified! "
            "Use `python run.py mode=exp name=experiment_name`"
        )
        log.info("Exiting...")
        # The bare `exit()` builtin is added by the `site` module for interactive use and
        # may be missing (e.g. `python -S`, frozen apps); `sys.exit` always exists.
        # A missing experiment name is a configuration error, so exit with a non-zero status.
        sys.exit(1)

    # force debugger friendly configuration if <config.trainer.fast_dev_run=True>
    # debuggers don't like GPUs and multiprocessing
    trainer_cfg = config.get("trainer")
    if trainer_cfg is not None and trainer_cfg.get("fast_dev_run"):
        log.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
        if config.trainer.get("gpus"):
            config.trainer.gpus = 0
        # the datamodule section is optional; only adjust it when present
        if config.get("datamodule") is not None:
            if config.datamodule.get("pin_memory"):
                config.datamodule.pin_memory = False
            if config.datamodule.get("num_workers"):
                config.datamodule.num_workers = 0
@rank_zero_only
def print_config(
    config: DictConfig,
    fields: Sequence[str] = (
        "trainer",
        "model",
        "datamodule",
        "train",
        "eval",
        "callbacks",
        "logger",
        "seed",
        "name",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    The same tree is also written to ``config_tree.txt`` in the current working
    directory. Decorated with ``rank_zero_only``, so it only runs on the rank-zero
    process.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        fields (Sequence[str], optional): Determines which main fields from config will
            be printed and in what order. Fields missing from ``config`` are shown as ``None``.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)
        config_section = config.get(field)
        # Any OmegaConf container (DictConfig *or* ListConfig) is rendered as YAML so that
        # `resolve` is honoured; plain values (scalars, None for missing fields) use str().
        if OmegaConf.is_config(config_section):
            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
        else:
            branch_content = str(config_section)
        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # Explicit UTF-8 so that writing non-ASCII config values does not depend on the
    # platform's default encoding (e.g. cp1252 on Windows).
    with open("config_tree.txt", "w", encoding="utf-8") as fp:
        rich.print(tree, file=fp)
def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[pl.loggers.LightningLoggerBase],
) -> None:
    """Close external experiment-tracking resources at the end of a run.

    Only ``logger`` is inspected; ``config``, ``model``, ``datamodule``, ``trainer``
    and ``callbacks`` are currently unused by this function. ``wandb.finish()`` is
    called once for every ``WandbLogger`` found in ``logger``.

    Args:
        config: Configuration composed by Hydra (unused).
        model: The Lightning module (unused).
        datamodule: The Lightning datamodule (unused).
        trainer: The Lightning trainer (unused).
        callbacks: Trainer callbacks (unused).
        logger: Experiment loggers used during the run.
    """
    for experiment_logger in logger:
        if not isinstance(experiment_logger, pl.loggers.wandb.WandbLogger):
            continue
        # without this, sweeps with the wandb logger might crash;
        # wandb is imported lazily so it is only required when a WandbLogger is in use
        import wandb

        wandb.finish()
|