from copy import deepcopy
from typing import Any, List, Type, Union

import jsonargparse
import pytorch_lightning as pl
import torch
from diffusers import DDPMScheduler, DiffusionPipeline

from modules.loader.module_loader_config import ModuleLoaderConfig
from utils.loader import get_class


class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


class GenericModuleLoader():

    def __init__(self,
                 pipeline_repo: str = None,
                 pipeline_obj: str = None,
                 set_prediction_type: str = "",
                 module_names: List[str] = [
                     "scheduler", "text_encoder", "tokenizer", "vae", "unet"],
                 module_config: dict[str, Union[ModuleLoaderConfig, torch.nn.Module, Any]] = None,
                 fast_dev_run: Union[int, bool] = False,
                 root_cls: Type[Any] = None,
                 ) -> None:
        self.module_config = module_config
        self.pipeline_repo = pipeline_repo
        self.pipeline_obj = pipeline_obj
        self.set_prediction_type = set_prediction_type
        self.module_names = module_names
        self.fast_dev_run = fast_dev_run
        self.root_cls = root_cls

    def load_custom_scheduler(self):
        module_obj = DDPMScheduler.from_pretrained(
            self.pipeline_repo, subfolder="scheduler")
        if len(self.set_prediction_type) > 0:
            # rebuild the scheduler with the requested prediction type
            scheduler_config = module_obj.load_config(
                self.pipeline_repo, subfolder="scheduler")
            scheduler_config["prediction_type"] = self.set_prediction_type
            module_obj = module_obj.from_config(scheduler_config)
        return module_obj

    def load_pipeline(self):
        return DiffusionPipeline.from_pretrained(self.pipeline_repo) if self.pipeline_repo is not None else None

    def __call__(self, trainer: pl.LightningModule, diff_trainer_params):
        # load the diffusers pipeline object if set
        if self.pipeline_obj is not None:
            pipe = self.load_pipeline()
        else:
            pipe = None
        if pipe is not None and self.pipeline_obj is not None:
            # store the entire diffusers pipeline object under the name given by pipeline_obj
            setattr(trainer, self.pipeline_obj, pipe)

        for module_name in self.module_names:
            print(f" --- START: Loading module: {module_name} ---")
            if module_name not in self.module_config.keys() and pipe is not None:
                # take the module from the already loaded diffusers pipeline
                module_obj = getattr(pipe, module_name)
                if module_name == "scheduler":
                    module_obj = self.load_custom_scheduler()
                setattr(trainer, module_name, module_obj)
            else:
                if not isinstance(self.module_config[module_name], ModuleLoaderConfig):
                    # the model was already instantiated by jsonargparse, so store it directly
                    module = self.module_config[module_name]
                    # TODO we want to be able to load a ckpt here as well.
                    config_obj = None
                else:
                    # instantiate the object from a class method (as used by
                    # diffusers, e.g. DiffusionPipeline.from_pretrained)
                    config_obj = self.module_config[module_name]
                    # retrieve loader class
                    loader_cls = get_class(config_obj.loader_cls_path)
                    # retrieve loader method
                    if config_obj.cls_func != "":
                        # A dedicated method for fast loading can be specified
                        # (e.g. in diffusers, from_config instead of
                        # from_pretrained), which makes loading faster for quick testing.
                        if not self.fast_dev_run or config_obj.cls_func_fast_dev_run == "":
                            cls_func = getattr(
                                loader_cls, config_obj.cls_func)
                        else:
                            print(
                                f"Model {module_name}: loading fast_dev_run class loader")
                            cls_func = getattr(
                                loader_cls, config_obj.cls_func_fast_dev_run)
                    else:
                        cls_func = loader_cls

                    # retrieve parameters
                    # parameters specified in diff_trainer_params (so they stay linked)
                    kwargs_trainer_params = config_obj.kwargs_diff_trainer_params
                    kwargs_diffusers = config_obj.kwargs_diffusers
                    # names of dependent modules that are needed as input
                    dependent_modules = config_obj.dependent_modules
                    # names of dependent modules that are needed as input; these modules will be cloned
                    dependent_modules_cloned = config_obj.dependent_modules_cloned
                    # model kwargs: either a plain dict or a parameter class
                    # (derived from modules.params.params_mixin.AsDictMixin) so inputs are verified
                    model_params = config_obj.model_params
                    # kwargs used only in fast_dev_run mode
                    model_params_fast_dev_run = config_obj.model_params_fast_dev_run

                    if model_params is not None:
                        if isinstance(model_params, dict):
                            model_dict = model_params
                        else:
                            model_dict = model_params.to_dict()
                    else:
                        model_dict = {}

                    if (model_params_fast_dev_run is None) or (not self.fast_dev_run):
                        model_params_fast_dev_run = {}
                    else:
                        print(
                            f"Module {module_name}: loading fast_dev_run params")

                    loaded_modules_dict = {}
                    if dependent_modules is not None:
                        for key, dependent_module in dependent_modules.items():
                            assert hasattr(
                                trainer, dependent_module), f"Module {dependent_module} not available. Set {dependent_module} before module {module_name} in module_loader.module_names. Current order: {self.module_names}"
                            loaded_modules_dict[key] = getattr(
                                trainer, dependent_module)
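                    # dependent_modules above passes the trainer's modules by
                    # reference (the new module shares them), whereas
                    # dependent_modules_cloned below hands over deep copies,
                    # so the new module owns independent instances.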
                    if dependent_modules_cloned is not None:
                        for key, dependent_module in dependent_modules_cloned.items():
                            assert hasattr(
                                trainer, dependent_module), f"Module {dependent_module} not available. Set {dependent_module} before module {module_name} in module_loader.module_names. Current order: {self.module_names}"
                            loaded_modules_dict[key] = deepcopy(
                                getattr(trainer, dependent_module))

                    if kwargs_trainer_params is not None:
                        for key, param in kwargs_trainer_params.items():
                            if param is not None:
                                kwargs_trainer_params[key] = getattr(
                                    diff_trainer_params, param)
                            else:
                                kwargs_trainer_params[key] = diff_trainer_params
                    else:
                        kwargs_trainer_params = {}

                    if kwargs_diffusers is None:
                        kwargs_diffusers = {}
                    else:
                        # map string-encoded dtypes from the config to torch dtypes
                        for key, value in kwargs_diffusers.items():
                            if key == "torch_dtype":
                                if value == "torch.float16":
                                    kwargs_diffusers[key] = torch.float16

                    kwargs = kwargs_diffusers | loaded_modules_dict | kwargs_trainer_params | model_dict | model_params_fast_dev_run
                    args = config_obj.args
                    # instantiate object
                    module = cls_func(*args, **kwargs)
                    module: torch.nn.Module
                    if self.root_cls is not None:
                        assert isinstance(module, self.root_cls)

                if config_obj is not None and config_obj.state_dict_path != "" and not self.fast_dev_run:
                    # TODO extend loading to hf spaces
                    print(
                        f" * Loading checkpoint {config_obj.state_dict_path} - STARTED")
                    module_state_dict = torch.load(
                        config_obj.state_dict_path, map_location=torch.device("cpu"))
                    module_state_dict = module_state_dict["state_dict"]

                    if len(config_obj.state_dict_filters) > 0:
                        assert not config_obj.strict_loading
                        ckpt_params_dict = {}
                        # keep only the parameters whose (module-prefixed) name
                        # matches one of the wildcard filters, e.g. "unet.*attention*"
                        for name, param in module.named_parameters(prefix=module_name):
                            for filter_str in config_obj.state_dict_filters:
                                filter_groups = filter_str.split("*")
                                # first pass: every filter part must occur in the name
                                has_all_parts = True
                                for filter_group in filter_groups:
                                    has_all_parts = has_all_parts and filter_group in name
                                if has_all_parts:
                                    # second pass: the parts must occur in order
                                    validate_name = name
                                    for filter_group in filter_groups:
                                        if filter_group in validate_name:
                                            shift = validate_name.index(
                                                filter_group)
                                            validate_name = validate_name[shift+len(
                                                filter_group):]
                                        else:
                                            has_all_parts = False
                                            break
                                    if has_all_parts:
                                        ckpt_params_dict[name[len(
                                            module_name+"."):]] = module_state_dict[name]
                    else:
                        # take all checkpoint entries belonging to this module
                        ckpt_params_dict = dict(filter(lambda x: x[0].startswith(
                            module_name), module_state_dict.items()))
                        # strip the module prefix from the parameter names
                        ckpt_params_dict = {
                            k.split(module_name+".")[1]: v for (k, v) in ckpt_params_dict.items()}

                    if len(ckpt_params_dict) > 0:
                        miss, unex = module.load_state_dict(
                            ckpt_params_dict, strict=config_obj.strict_loading)
                        ckpt_params_dict = {}
                        assert len(
                            unex) == 0, f"Unexpected parameters in checkpoint: {unex}"
                        if len(miss) > 0:
                            print(
                                f"Checkpoint {config_obj.state_dict_path} is missing parameters for module {module_name}.")
                            print(miss)
                    print(
                        f" * Loading checkpoint {config_obj.state_dict_path} - FINISHED")

                if isinstance(module, jsonargparse.Namespace) or isinstance(module, dict):
                    print(bcolors.WARNING +
                          f"Warning: it seems object {module_name} was not built correctly." + bcolors.ENDC)
                setattr(trainer, module_name, module)
            print(f" --- FINISHED: Loading module: {module_name} ---")
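

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): one plausible way to drive
# GenericModuleLoader. The repo id, DummyTrainer, and passing None for
# diff_trainer_params are assumptions of this sketch, not part of the module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class DummyTrainer(pl.LightningModule):
        """Minimal LightningModule acting purely as an attribute container."""

    trainer = DummyTrainer()
    loader = GenericModuleLoader(
        pipeline_repo="runwayml/stable-diffusion-v1-5",   # assumed HF repo id
        pipeline_obj="pipe",                 # store the full pipeline as trainer.pipe
        set_prediction_type="v_prediction",  # rebuild the scheduler with this prediction type
        module_config={},                    # no overrides: take every module from the pipeline
    )
    # Attaches trainer.pipe, trainer.scheduler, trainer.text_encoder,
    # trainer.tokenizer, trainer.vae, and trainer.unet.
    loader(trainer, diff_trainer_params=None)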