import gc
import os
import sys
from typing import Any, Dict, List

import torch

# Make the bundled ``uaiDiffusers`` package importable regardless of the
# current working directory; this must run before the imports below.
rootDir = os.path.abspath(os.path.dirname(__file__))
sys.path.append(rootDir)

from diffusers import DPMSolverSinglestepScheduler, StableDiffusionXLPipeline
from uaiDiffusers.common.imageRequest import ImageRequest
from uaiDiffusers.uaiDiffusers import ImagesToBase64


class EndpointHandler:
    """Inference handler that serves images from a Stable Diffusion XL pipeline.

    A pipeline for the default model is built at construction time. Later
    requests reuse it, and a new pipeline is built only when a request names
    a different model.
    """

    # Model id used when a request asks for "default" or names no model.
    DEFAULT_MODEL = "sd-community/sdxl-flash"

    def __init__(self, path=""):
        # ``path`` is currently unused; the model comes from DEFAULT_MODEL.
        self.pipe = None
        # Id of the model currently held in ``self.pipe`` ("" when none).
        self.modelName = ""
        baseReq = ImageRequest()
        baseReq.model = self.DEFAULT_MODEL
        print(f"Loading model: {self.DEFAULT_MODEL}")
        self.LoadModel(baseReq)

    def LoadModel(self, request):
        """Make sure ``self.pipe`` is the pipeline for ``request.model``.

        If ``request.model`` is ``"default"`` or empty, it is rewritten in place
        to ``DEFAULT_MODEL``. The pipeline is (re)built only when none is loaded
        or a different model is requested. The previous pipeline is released
        before the new one is loaded, so two models are never held at once.
        """
        if not request.model or request.model == "default":
            request.model = self.DEFAULT_MODEL
        base = request.model

        if self.pipe is not None and self.modelName == base:
            # Already loaded: avoid reloading and re-moving weights on every request.
            return

        if self.pipe is not None:
            # Drop the old pipeline and release cached CUDA memory before
            # allocating the replacement.
            self.pipe = None
            self.modelName = ""
            gc.collect()
            torch.cuda.empty_cache()

        pipe = StableDiffusionXLPipeline.from_pretrained(base).to("cuda")
        # Ensure sampler uses "trailing" timesteps.
        pipe.scheduler = DPMSolverSinglestepScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        # Only record the model as loaded once construction fully succeeded.
        self.pipe = pipe
        self.modelName = base

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        data args:
            input (:obj: `str` | `PIL.Image` | `np.array`)
            seed (:obj: `int`)
            prompt (:obj: `str`)
            negative_prompt (:obj: `str`)
            num_images_per_prompt (:obj: `int`)
            steps (:obj: `int`)
            guidance_scale (:obj: `float`)
            width (:obj: `int`)
            height (:obj: `int`)
            kwargs
        Return:
            A dict ``{"media": [{"media": ...}, ...]}`` with one entry per
            generated image; it will be serialized and returned.
        """
        request = ImageRequest.FromDict(data)
        return self.__runProcess__(request)

    def __runProcess__(self, request: ImageRequest) -> Dict[str, Any]:
        """Generate images for ``request`` with the SDXL pipeline.

        Loads or switches the model if needed, runs the pipeline on the prompt,
        and returns ``{"media": [{"media": ImagesToBase64(img)}, ...]}``.
        """
        self.LoadModel(request)
        # Steps and guidance scale come straight from the request.
        # NOTE(review): seed, width and height (listed in __call__'s docstring)
        # are not forwarded to the pipeline here — confirm this is intended.
        images = self.pipe(
            request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.steps,
            guidance_scale=request.guidance_scale,
            num_images_per_prompt=request.num_images_per_prompt,
        ).images
        return {"media": [{"media": ImagesToBase64(img)} for img in images]}