from typing import List, Optional
from pathlib import Path

import torch
import hydra
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.loggers import LightningLoggerBase

from src.utils import utils

log = utils.get_logger(__name__)

# Values accepted for ``config.eval.checkpoint_type``.
_SUPPORTED_CHECKPOINT_TYPES = ('lightning', 'pytorch')


def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` with a leading ``prefix`` stripped.

    ``text`` is returned unchanged when it does not start with ``prefix``.
    """
    if text.startswith(prefix):
        return text[len(prefix):]
    return text


def load_checkpoint(path, device='cpu'):
    """Load a checkpoint file and return the state dict it contains.

    Args:
        path: Checkpoint file, or a directory, in which case ``last.ckpt``
            inside that directory is loaded. ``~`` is expanded.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The loaded object after unwrapping the known container layouts
        handled below. Anything else is returned exactly as loaded.
    """
    path = Path(path).expanduser()
    if path.is_dir():
        path /= 'last.ckpt'
    log.info(f'Loading checkpoint from {str(path)}')
    # NOTE(review): torch.load unpickles the file, so only point this at
    # checkpoints from a trusted source.
    state_dict = torch.load(path, map_location=device)
    # T2T-ViT checkpoint is nested in the key 'state_dict_ema'
    if state_dict.keys() == {'state_dict_ema'}:
        state_dict = state_dict['state_dict_ema']
    # Swin checkpoint is nested in the key 'model'
    if state_dict.keys() == {'model'}:
        state_dict = state_dict['model']
    # Lightning checkpoint contains extra stuff, we only want the model state
    # dict. Its keys carry a 'model.' prefix that is stripped here.
    if 'pytorch-lightning_version' in state_dict:
        state_dict = {
            remove_prefix(key, 'model.'): value
            for key, value in state_dict['state_dict'].items()
        }
    return state_dict


def _load_model(config: DictConfig) -> LightningModule:
    """Build the module to evaluate according to ``config.eval.checkpoint_type``."""
    checkpoint_type = config.eval.get('checkpoint_type', 'pytorch')
    if checkpoint_type not in _SUPPORTED_CHECKPOINT_TYPES:
        raise NotImplementedError(f'checkpoint_type {checkpoint_type} not supported')

    if checkpoint_type == 'lightning':
        # The class named by config.task._target_ restores itself from the
        # checkpoint given in config.eval.ckpt (required in this mode).
        cls = hydra.utils.get_class(config.task._target_)
        return cls.load_from_checkpoint(checkpoint_path=config.eval.ckpt)

    # checkpoint_type == 'pytorch'
    if 'model_pretrained' in config:
        # This path was only ever a placeholder: no model was produced and the
        # run died much later with an UnboundLocalError. Fail up front, before
        # any model or data is instantiated.
        raise NotImplementedError(
            "Evaluation with 'model_pretrained' in the config is not implemented"
        )

    # model_cfg=None matches what was passed when 'model_pretrained' is absent.
    trained_model: LightningModule = hydra.utils.instantiate(
        config.task, cfg=config, model_cfg=None, _recursive_=False
    )
    if 'ckpt' in config.eval:
        # strict=False: mismatching keys do not raise; the value returned by
        # load_state_dict is logged so mismatches stay visible.
        load_return = trained_model.model.load_state_dict(
            load_checkpoint(config.eval.ckpt, device=trained_model.device), strict=False
        )
        log.info(load_return)
    return trained_model


def _instantiate_section(config: DictConfig, section: str, kind: str) -> list:
    """Instantiate every entry of ``config[section]`` that declares a ``_target_``.

    ``kind`` is only used in the log message. Returns an empty list when the
    section is absent.
    """
    instances = []
    if section in config:
        for _, entry_conf in config[section].items():
            if entry_conf is not None and "_target_" in entry_conf:
                log.info(f"Instantiating {kind} <{entry_conf._target_}>")
                instances.append(hydra.utils.instantiate(entry_conf))
    return instances


def evaluate(config: DictConfig) -> None:
    """Evaluate a trained model on the validation and/or test data.

    Builds the model (see ``config.eval.checkpoint_type``, default
    ``'pytorch'``), the datamodule, callbacks, loggers and trainer from
    ``config``, then runs ``trainer.validate`` and ``trainer.test``. Each run
    can be disabled with ``config.eval.run_val`` / ``config.eval.run_test``
    (both default to ``True``).

    Args:
        config: Hydra configuration. Struct mode is switched off on it, so the
            object passed in is modified.

    Raises:
        NotImplementedError: If ``config.eval.checkpoint_type`` is not
            ``'lightning'`` or ``'pytorch'``, or if ``'pytorch'`` is combined
            with a ``model_pretrained`` entry in the config.
    """
    # We want to add fields to config so need to call OmegaConf.set_struct
    OmegaConf.set_struct(config, False)

    model = _load_model(config)

    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
    datamodule.prepare_data()
    datamodule.setup()

    # print model hyperparameters
    log.info(f'Model hyperparameters: {model.hparams}')

    # Init Lightning callbacks and loggers
    callbacks: List[Callback] = _instantiate_section(config, "callbacks", "callback")
    logger: List[LightningLoggerBase] = _instantiate_section(config, "logger", "logger")

    # Init Lightning trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )

    # Evaluate the model
    log.info("Starting evaluation!")
    if config.eval.get('run_val', True):
        trainer.validate(model=model, datamodule=datamodule)
    if config.eval.get('run_test', True):
        trainer.test(model=model, datamodule=datamodule)

    # Make sure everything closed properly
    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )