import os
import json
import logging
import importlib
from typing import Type, Tuple, Union, List, Dict, Any
import torch
import diffusers
import onnxruntime as ort


def extract_device(args: List, kwargs: Dict):
    """Find a target device among call arguments (presumably those of a ``.to(...)``-style call).

    ``kwargs["device"]`` is returned when present and not ``None``. Otherwise the
    positional ``args`` are scanned and the last ``torch.device`` instance found is
    returned. Device strings in ``args`` are not recognized.

    Returns:
        The device found, or ``None`` if there is none.
    """
    device = kwargs.get("device", None)
    if device is None:
        for arg in args:
            if isinstance(arg, torch.device):
                device = arg  # no break: the last torch.device in args wins
    return device


def move_inference_session(session: ort.InferenceSession, device: torch.device):
    """Re-create an ONNX Runtime session on the execution provider mapped from ``device``.

    Args:
        session: Existing inference session. Its private ``_providers``,
            ``_model_path`` and ``_sess_options`` are reused.
        device: Target device. A device string such as ``"cuda"`` is also accepted
            and converted with ``torch.device``.

    Returns:
        ``session`` itself when ``modules.devices.device`` is a CPU device.
        Otherwise, a session loaded with ``diffusers.OnnxRuntimeModel.load_model`` on
        the provider from ``TORCH_DEVICE_TO_EP`` (or on the previous provider when
        the device type has no mapping). If that load raises, a ``TemporalModule``
        built from the previous provider, model path and session options is
        returned instead.
    """
    from modules.devices import device as default_device
    if default_device.type == "cpu":
        # CPU-only torch with no external backend overriding it: moving the session
        # to another provider would be a mistake, so keep it as-is.
        return session
    from . import DynamicSessionOptions, TemporalModule
    from .execution_providers import TORCH_DEVICE_TO_EP
    if isinstance(device, str):
        # Accept device strings too; the code below needs ``device.type``.
        device = torch.device(device)
    previous_provider = session._providers  # pylint: disable=protected-access
    provider = TORCH_DEVICE_TO_EP[device.type] if device.type in TORCH_DEVICE_TO_EP else previous_provider
    path = session._model_path  # pylint: disable=protected-access
    try:
        return diffusers.OnnxRuntimeModel.load_model(path, provider, DynamicSessionOptions.from_sess_options(session._sess_options))  # pylint: disable=protected-access
    except Exception:
        # Deliberate best-effort fallback: keep the model on the provider it was already using.
        return TemporalModule(previous_provider, path, session._sess_options)  # pylint: disable=protected-access


def check_diffusers_cache(path: os.PathLike):
    """Return whether ``opts.diffusers_dir`` occurs in the absolute form of ``path``.

    Note: this is a plain substring test on the path string, not a
    path-component-aware containment check.
    """
    from modules.shared import opts
    return opts.diffusers_dir in os.path.abspath(path)


def check_pipeline_sdxl(cls: Type[diffusers.DiffusionPipeline]) -> bool:
    """Return whether the pipeline class name contains ``'XL'`` (treated as an SDXL pipeline)."""
    return 'XL' in cls.__name__


def check_cache_onnx(path: os.PathLike) -> bool:
    """Return whether ``path`` is a directory whose ``model_index.json`` mentions ``OnnxRuntimeModel``.

    The check is a substring search on the raw file text; the JSON is not parsed.
    """
    if not os.path.isdir(path):
        return False
    init_dict_path = os.path.join(path, "model_index.json")
    if not os.path.isfile(init_dict_path):
        return False
    with open(init_dict_path, "r", encoding="utf-8") as file:
        content = file.read()
    return "OnnxRuntimeModel" in content


def load_init_dict(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike) -> Dict[str, List[Union[str, None]]]:
    """Collect the component entries of the pipeline config at ``path``.

    Every dict returned by ``cls.extract_init_dict`` is merged; later dicts
    override earlier ones. Only list-valued entries are kept (presumably
    ``[library, class_name]`` component specs), and only when their key is an
    annotated parameter of ``cls.__init__``. Non-list entries, such as hidden
    ``_``-prefixed config keys, are dropped.
    """
    merged: Dict[str, Any] = {}
    for item in cls.extract_init_dict(diffusers.DiffusionPipeline.load_config(path)):
        merged.update(item)
    # ``isinstance`` is checked first so ``__annotations__`` is only consulted for list values.
    return {
        k: v
        for k, v in merged.items()
        if isinstance(v, list) and k in cls.__init__.__annotations__
    }


def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: List[Union[str, None]], **kwargs_ort):
    """Instantiate one pipeline component described by ``item == [library, class_name]``.

    Returns ``None`` when either entry of ``item`` is ``None``.

    Components whose class subclasses ``diffusers.OnnxRuntimeModel`` are loaded
    with ``kwargs_ort``:

    - for SDXL pipelines, directly from ``<path>/<submodel_name>/model.onnx`` via
      ``OnnxRuntimeModel.load_model``;
    - otherwise via ``OnnxRuntimeModel.from_pretrained`` on ``<path>/<submodel_name>``.

    Any other class is loaded with its own ``from_pretrained`` on
    ``<path>/<submodel_name>``, without ``kwargs_ort``.

    Raises:
        ValueError if ``item`` does not have exactly two entries. Errors from the
        import or load step are propagated; ``load_submodels`` catches them.
    """
    lib, atr = item
    if lib is None or atr is None:
        return None
    attribute = getattr(importlib.import_module(lib), atr)
    submodel_path = os.path.join(path, submodel_name)
    if not issubclass(attribute, diffusers.OnnxRuntimeModel):
        return attribute.from_pretrained(submodel_path)
    if is_sdxl:
        return diffusers.OnnxRuntimeModel.load_model(os.path.join(submodel_path, "model.onnx"), **kwargs_ort)
    return diffusers.OnnxRuntimeModel.from_pretrained(submodel_path, **kwargs_ort)


def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: Dict[str, Any], **kwargs_ort):
    """Load every component listed in ``init_dict`` and return them keyed by component name.

    Non-list values are copied through unchanged. A component whose loading raises
    is skipped (best-effort) and a warning is logged. If the pipeline constructor
    requires that component, building the pipeline may then fail with a
    missing-argument error.
    """
    loaded = {}
    for k, v in init_dict.items():
        if not isinstance(v, list):
            loaded[k] = v
            continue
        try:
            loaded[k] = load_submodel(path, is_sdxl, k, v, **kwargs_ort)
        except Exception as e:
            # Previously swallowed silently; keep the best-effort behavior but make the failure visible.
            logging.getLogger(__name__).warning("Failed to load submodel '%s' from '%s': %s", k, path, e)
    return loaded


def load_pipeline(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike, **kwargs_ort) -> diffusers.DiffusionPipeline:
    """Build a ``cls`` pipeline from ``path``.

    A directory is loaded component by component (``kwargs_ort`` is forwarded to
    the ONNX components). Anything else is passed to ``cls.from_single_file``;
    ``kwargs_ort`` is not used on that path.
    """
    if os.path.isdir(path):
        init_dict = load_init_dict(cls, path)
        submodels = load_submodels(path, check_pipeline_sdxl(cls), init_dict, **kwargs_ort)
        return cls(**patch_kwargs(cls, submodels))
    return cls.from_single_file(path)


def patch_kwargs(cls: Type[diffusers.DiffusionPipeline], kwargs: Dict) -> Dict:
    """Add class-specific constructor arguments to ``kwargs`` in place and return it.

    - The ONNX SD 1.x/2.x pipelines get ``safety_checker=None`` and
      ``requires_safety_checker=False``.
    - The ONNX SDXL pipelines get an empty ``config``.
    """
    if cls in (
        diffusers.OnnxStableDiffusionPipeline,
        diffusers.OnnxStableDiffusionImg2ImgPipeline,
        diffusers.OnnxStableDiffusionInpaintPipeline,
    ):
        kwargs["safety_checker"] = None
        kwargs["requires_safety_checker"] = False
    if cls in (
        diffusers.OnnxStableDiffusionXLPipeline,
        diffusers.OnnxStableDiffusionXLImg2ImgPipeline,
    ):
        kwargs["config"] = {}
    return kwargs


def get_base_constructor(cls: Type[diffusers.DiffusionPipeline], is_refiner: bool):
    """Map an img2img/inpaint pipeline class to its text-to-image base class.

    - SD img2img/inpaint map to ``OnnxStableDiffusionPipeline``.
    - SDXL img2img maps to ``OnnxStableDiffusionXLPipeline`` unless ``is_refiner``.
    - Any other class is returned unchanged.
    """
    if cls in (
        diffusers.OnnxStableDiffusionImg2ImgPipeline,
        diffusers.OnnxStableDiffusionInpaintPipeline,
    ):
        return diffusers.OnnxStableDiffusionPipeline
    if cls == diffusers.OnnxStableDiffusionXLImg2ImgPipeline and not is_refiner:
        return diffusers.OnnxStableDiffusionXLPipeline
    return cls


def get_io_config(submodel: str, is_sdxl: bool):
    """Read the ``io_config`` of a submodel from ``<sd_configs_path>/olive/{sdxl|sd}/<submodel>.json``.

    The value is read from ``input_model.config.io_config`` in that JSON.
    """
    from modules.paths import sd_configs_path
    config_path = os.path.join(sd_configs_path, "olive", 'sdxl' if is_sdxl else 'sd', f"{submodel}.json")
    with open(config_path, "r", encoding="utf-8") as config_file:
        io_config: Dict[str, Any] = json.load(config_file)["input_model"]["config"]["io_config"]
    # JSON object keys are always strings; convert the axis indices back to int.
    for axe in io_config["dynamic_axes"]:
        io_config["dynamic_axes"][axe] = {int(k): v for k, v in io_config["dynamic_axes"][axe].items()}
    return io_config