Spaces:
Sleeping
Sleeping
File size: 5,362 Bytes
e45d058 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
from typing import List, Optional, Sequence
from pathlib import Path
import hydra
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import (
Callback,
LightningDataModule,
LightningModule,
Trainer,
seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase
from src.utils import utils
# Module-level logger, obtained from the project helper `utils.get_logger`
# using this module's `__name__`; used below for progress messages.
log = utils.get_logger(__name__)
def last_modification_time(path):
    """Return the most recent modification time found at ``path``.

    For a regular file this is the file's own mtime. For a directory it is the
    newest mtime among its direct children (files / directories one level
    below ``path``; no recursion). An empty directory has no children, so the
    directory's own mtime is returned instead of failing.

    Args:
        path: File or directory to inspect (``str`` or path-like).

    Returns:
        The modification time as a POSIX timestamp, or ``None`` when ``path``
        is neither an existing file nor an existing directory.
    """
    path = Path(path)
    if path.is_file():
        return path.stat().st_mtime
    if path.is_dir():
        # `default` keeps max() from raising ValueError on an empty directory.
        return max(
            (child.stat().st_mtime for child in path.iterdir()),
            default=path.stat().st_mtime,
        )
    return None
def _instantiate_targets(section, kind: str) -> list:
    """Instantiate every entry of ``section`` that declares a ``_target_``.

    Args:
        section: Mapping of name -> sub-config (e.g. ``config.callbacks``).
        kind: Human-readable label used in the log message ("callback" / "logger").

    Returns:
        list: The instantiated objects, in the iteration order of ``section``.
    """
    instances = []
    for _, conf in section.items():
        # Entries set to None or lacking `_target_` are skipped.
        if conf is not None and "_target_" in conf:
            log.info(f"Instantiating {kind} <{conf._target_}>")
            instances.append(hydra.utils.instantiate(conf))
    return instances


def _resolve_resume_checkpoint(config: DictConfig) -> dict:
    """Build the ``ckpt_path`` keyword arguments for ``trainer.fit`` when resuming.

    Looks at ``config.callbacks.model_checkpoint.dirpath``. When that is a
    directory, picks ``.pl_auto_save.ckpt`` if ``last.ckpt`` is missing or
    older than it, and ``last.ckpt`` otherwise.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Returns:
        dict: ``{'ckpt_path': <path>}`` when a checkpoint was found, otherwise
        an empty dict (training starts from scratch).
    """
    try:
        checkpoint_path = Path(config.callbacks.model_checkpoint.dirpath)
        if checkpoint_path.is_dir():
            last_ckpt = checkpoint_path / 'last.ckpt'
            autosave_ckpt = checkpoint_path / '.pl_auto_save.ckpt'
            if not (last_ckpt.exists() or autosave_ckpt.exists()):
                raise FileNotFoundError(
                    "Resume requires either last.ckpt or .pl_auto_save.ckpt")
            # Prefer the autosave when it is the only one present or the newer one.
            autosave_is_newer = (
                autosave_ckpt.exists()
                and last_ckpt.exists()
                and last_modification_time(autosave_ckpt) > last_modification_time(last_ckpt)
            )
            if not last_ckpt.exists() or autosave_is_newer:
                checkpoint_path = autosave_ckpt
            else:
                checkpoint_path = last_ckpt
        # DeepSpeed's checkpoint is a directory, not a file
        if checkpoint_path.is_file() or checkpoint_path.is_dir():
            return {'ckpt_path': str(checkpoint_path)}
        log.info(f'Checkpoint file {str(checkpoint_path)} not found. Will start training from scratch')
    except (KeyError, FileNotFoundError) as err:
        # Resuming is best-effort: report why it failed instead of hiding it,
        # then fall back to training from scratch.
        log.info(f'Could not resume from checkpoint ({err!r}). Will start training from scratch')
    return {}


def train(config: DictConfig) -> Optional[float]:
    """Contains training pipeline.

    Instantiates all PyTorch Lightning objects from config, optionally resumes
    from a checkpoint, fits the model, optionally runs the test loop and
    finalizes through ``utils.finish``.

    Note that ``config`` is modified in place: struct mode is switched off and
    ``config.trainer.strategy`` may be filled in with a DDP strategy.

    Args:
        config (DictConfig): Configuration composed by Hydra. Keys read here:
            ``seed``, ``task``, ``callbacks``, ``logger``, ``resume``,
            ``trainer``, ``test_after_training`` and ``optimized_metric``.

    Returns:
        Optional[float]: ``trainer.callback_metrics[optimized_metric]`` when
        ``optimized_metric`` is set in the config (metric score for
        hyperparameter optimization), otherwise ``None``.
    """
    # Set seed for random number generators in pytorch, numpy and python.random
    if config.get("seed"):
        seed_everything(config.seed, workers=True)

    # We want to add fields to config so need to call OmegaConf.set_struct
    OmegaConf.set_struct(config, False)

    # Init lightning model; the datamodule is taken from the model itself.
    model: LightningModule = hydra.utils.instantiate(config.task, cfg=config, _recursive_=False)
    datamodule: LightningDataModule = model._datamodule

    # Init lightning callbacks
    callbacks: List[Callback] = []
    if "callbacks" in config:
        callbacks = _instantiate_targets(config.callbacks, "callback")

    # Init lightning loggers
    logger: List[LightningLoggerBase] = []
    if "logger" in config:
        logger = _instantiate_targets(config.logger, "logger")

    # Extra keyword arguments for trainer.fit; empty unless resuming succeeds.
    ckpt_cfg = _resolve_resume_checkpoint(config) if config.get('resume') else {}

    # Configure ddp automatically
    n_devices = config.trainer.get('devices', 1)
    # NOTE(review): a string value (e.g. "auto") is also a Sequence and would be
    # counted by its character length here -- confirm only ints/lists are used.
    if isinstance(n_devices, Sequence):  # trainer.devices could be [1, 3] for example
        n_devices = len(n_devices)
    if n_devices > 1 and config.trainer.get('strategy', None) is None:
        config.trainer.strategy = dict(
            _target_='pytorch_lightning.strategies.DDPStrategy',
            find_unused_parameters=False,
            gradient_as_bucket_view=True,  # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
        )

    # Init lightning trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger)

    # Train the model
    log.info("Starting training!")
    trainer.fit(model=model, datamodule=datamodule, **ckpt_cfg)

    # Evaluate model on test set, using the best model achieved during training
    if config.get("test_after_training") and not config.trainer.get("fast_dev_run"):
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule)

    # Make sure everything closed properly
    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )

    # Print path to best checkpoint
    if not config.trainer.get("fast_dev_run"):
        log.info(f"Best model ckpt: {trainer.checkpoint_callback.best_model_path}")

    # Return metric score for hyperparameter optimization
    optimized_metric = config.get("optimized_metric")
    if optimized_metric:
        return trainer.callback_metrics[optimized_metric]
    return None
|