# NOTE: snapshot of this file taken from a web file viewer (commit e45d058, 4,851 bytes);
# the viewer's status text and line-number gutter have been removed.
from typing import List, Optional
from pathlib import Path
import torch
import hydra
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import (
Callback,
LightningDataModule,
LightningModule,
Trainer,
seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase
from src.utils import utils
# Module-level logger, obtained from the project's `src.utils.utils.get_logger` helper.
log = utils.get_logger(__name__)
def remove_prefix(text: str, prefix: str):
    """Return ``text`` with a single leading ``prefix`` removed.

    If ``text`` does not start with ``prefix`` it is returned unchanged.
    An empty ``prefix`` always matches and leaves ``text`` as-is.
    """
    starts_with_prefix = text.startswith(prefix)
    return text[len(prefix):] if starts_with_prefix else text
def load_checkpoint(path, device='cpu'):
    """Load a model state dict from a checkpoint file or directory.

    Args:
        path: Checkpoint file path (``~`` is expanded). If it is a directory,
            the file ``last.ckpt`` inside it is loaded instead.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The model weights. Checkpoints whose only top-level key is
        ``'state_dict_ema'`` (T2T-ViT) or ``'model'`` (Swin) are unwrapped, and
        PyTorch Lightning checkpoints are reduced to their ``'state_dict'``
        entry with the ``'model.'`` key prefix stripped.
    """
    ckpt_path = Path(path).expanduser()
    if ckpt_path.is_dir():
        ckpt_path = ckpt_path / 'last.ckpt'
    log.info(f'Loading checkpoint from {str(ckpt_path)}')
    # NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints
    # from trusted sources.
    state_dict = torch.load(ckpt_path, map_location=device)

    # Some checkpoints nest the weights under a single wrapper key. Check the
    # T2T-ViT wrapper first, then the Swin one (same order as before).
    for wrapper_key in ('state_dict_ema', 'model'):
        if state_dict.keys() == {wrapper_key}:
            state_dict = state_dict[wrapper_key]

    # A Lightning checkpoint carries extra training state; keep only the module
    # weights and drop the 'model.' prefix the LightningModule adds.
    if 'pytorch-lightning_version' in state_dict:
        lightning_weights = state_dict['state_dict']
        state_dict = {
            remove_prefix(key, 'model.'): value
            for key, value in lightning_weights.items()
        }
    return state_dict
def _load_model(config: DictConfig) -> LightningModule:
    """Build the model to evaluate, as selected by ``config.eval.checkpoint_type``.

    ``'lightning'``: load the ``config.task._target_`` class from
    ``config.eval.ckpt`` via ``load_from_checkpoint``.
    ``'pytorch'`` (default): instantiate ``config.task`` and, if
    ``config.eval.ckpt`` is set, load those weights into its ``.model``
    with ``strict=False``.

    Raises:
        NotImplementedError: For any other checkpoint type, or when
            ``model_pretrained`` is present on the ``'pytorch'`` path.
    """
    checkpoint_type = config.eval.get('checkpoint_type', 'pytorch')
    if checkpoint_type not in ['lightning', 'pytorch']:
        raise NotImplementedError(f'checkpoint_type {checkpoint_type} not supported')

    if checkpoint_type == 'lightning':
        # Model __init__ parameters are restored from the checkpoint itself.
        cls = hydra.utils.get_class(config.task._target_)
        return cls.load_from_checkpoint(checkpoint_path=config.eval.ckpt)

    # checkpoint_type == 'pytorch'
    model_cfg = config.model_pretrained if 'model_pretrained' in config else None
    trained_model: LightningModule = hydra.utils.instantiate(
        config.task, cfg=config, model_cfg=model_cfg, _recursive_=False
    )
    if 'ckpt' in config.eval:
        # strict=False: mismatched keys are returned (and logged) instead of raising.
        load_return = trained_model.model.load_state_dict(
            load_checkpoint(config.eval.ckpt, device=trained_model.device), strict=False
        )
        log.info(load_return)
    if 'model_pretrained' in config:
        # This path has no implementation; fail with a clear message instead of
        # continuing without a model (which previously surfaced later as an
        # UnboundLocalError).
        raise NotImplementedError(
            'Evaluation with config.model_pretrained is not implemented'
        )
    return trained_model


def _instantiate_targets(config: DictConfig, key: str, label: str) -> list:
    """Instantiate each entry of ``config[key]`` that is non-null and has a ``_target_``.

    ``label`` is only used in the log message. Returns an empty list when
    ``key`` is absent from ``config``.
    """
    instances = []
    if key in config:
        for entry_conf in config[key].values():
            if entry_conf is not None and "_target_" in entry_conf:
                log.info(f"Instantiating {label} <{entry_conf._target_}>")
                instances.append(hydra.utils.instantiate(entry_conf))
    return instances


def evaluate(config: DictConfig) -> None:
    """Evaluate a trained model on the validation and/or test split.

    Builds the model (see ``_load_model``), then instantiates the datamodule,
    callbacks, loggers and trainer from ``config``. It runs
    ``trainer.validate`` and/or ``trainer.test`` according to
    ``config.eval.run_val`` / ``config.eval.run_test``, both of which default
    to True.

    Args:
        config: Hydra config. Reads ``eval`` (``checkpoint_type``, ``ckpt``,
            ``run_val``, ``run_test``), ``task``, ``datamodule`` and
            ``trainer``, plus the optional ``model_pretrained``,
            ``callbacks`` and ``logger`` groups.

    Raises:
        NotImplementedError: If ``config.eval.checkpoint_type`` is neither
            'lightning' nor 'pytorch', or if ``model_pretrained`` is set for a
            'pytorch' checkpoint.
    """
    # Disable struct mode so fields can be added to the config later on.
    OmegaConf.set_struct(config, False)

    model = _load_model(config)

    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
    datamodule.prepare_data()
    datamodule.setup()

    log.info(f'Model hyperparameters: {model.hparams}')

    callbacks: List[Callback] = _instantiate_targets(config, "callbacks", "callback")
    logger: List[LightningLoggerBase] = _instantiate_targets(config, "logger", "logger")

    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )

    log.info("Starting evaluation!")
    if config.eval.get('run_val', True):
        trainer.validate(model=model, datamodule=datamodule)
    if config.eval.get('run_test', True):
        trainer.test(model=model, datamodule=datamodule)

    # Make sure everything is closed properly.
    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )