import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, AutoencoderTiny
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import random


class sdxl_lightning_loader:
    def __init__(self, ckpt="sdxl_lightning_4step_unet.safetensors",
                 base="stabilityai/stable-diffusion-xl-base-1.0",
                 repo="ByteDance/SDXL-Lightning"):
        # Loading this tiny VAE is important: it saves roughly 4 GB of VRAM and gives a
        # slight speed increase with almost no loss in quality.
        self.vae = AutoencoderTiny.from_pretrained(
            "madebyollin/taesdxl", use_safetensors=True, torch_dtype=torch.float16
        ).to("cuda:0")
        # Build the UNet from the base config, then load the SDXL-Lightning 4-step weights.
        self.unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda:0", torch.float16)
        self.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda:0"))
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            base, unet=self.unet, torch_dtype=torch.float16, variant="fp16", vae=self.vae
        ).to("cuda:0")
        # SDXL-Lightning expects the Euler scheduler with "trailing" timestep spacing.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config, timestep_spacing="trailing"
        )
        self.pipe.enable_vae_slicing()
        self.pipe.enable_vae_tiling()

    def infer(self, user_prompt, num_images=None, save=True, height=1024, width=1024,
              guidance_scale=0.75, neg_prompt=""):
        if not num_images:
            num_images = 1
        # Wrap the user prompt with quality-boosting tags and a standard negative prompt.
        prompt = [f"{user_prompt}((8k, RAW photo, highest quality, masterpiece), High detail RAW color photo professional photo, (highest quality), (best shadow), (best illustration), ultra high resolution, highly detailed CG unified 8K wallpapers, physics-based rendering, cinematic lighting)"]
        neg_prompt = [f"{neg_prompt}, lowres, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers"]
        samples = self.pipe(
            prompt=prompt, negative_prompt=neg_prompt, height=height, width=width,
            num_inference_steps=4, guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
        ).images
        random_num = random.randint(1, 100)
        output_dict = {}
        output_dict["llm_output"] = f"Done. Images have been generated and displayed. Stored as list variable {random_num}"
        output_dict["real_output"] = {"display": samples, "metadata": user_prompt}
        output_dict["type"] = "generated_image"
        output_dict["name"] = str(random_num)
        return output_dict


class sdxl_lightning:
    def __init__(self, sdxl_path):
        # Load a fused SDXL-Lightning checkpoint directly from a local or Hub path.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(sdxl_path, torch_dtype=torch.float16).to("cuda:0")
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config, timestep_spacing="trailing"
        )
        self.pipe.enable_vae_slicing()
        self.pipe.enable_vae_tiling()

    def infer(self, user_prompt, num_images=None, type="portrait", save=True, height=1024, width=1024,
              guidance_scale=0, neg_prompt="bad quality, deformed"):
        if not num_images:
            num_images = 1
        # Pick a resolution preset; any other value keeps the height/width arguments as passed.
        if type == "portrait":
            width, height = 1024, 1024
        elif type == "landscape":
            width, height = 1344, 768
        else:
            pass
        images = []
        prompt = [f"{user_prompt}((8k, RAW photo, highest quality, masterpiece), High detail RAW color photo professional photo, (highest quality), (best shadow), (best illustration), ultra high resolution, highly detailed CG unified 8K wallpapers, physics-based rendering, cinematic lighting)"]
        neg_prompt = [f"{neg_prompt}, lowres, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers"]
        samples = self.pipe(
            prompt=prompt, negative_prompt=neg_prompt, height=height, width=width,
            num_inference_steps=4, guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
        ).images
        # Save each generated image to disk and collect the saved file paths for display.
        for i in range(num_images):
            image = f"image{i}.jpeg"
            samples[i].save(image)
            images.append(image)
        output_dict = {}
        output_dict["llm_output"] = "Done. Images have been generated and displayed."
        output_dict["display"] = {"html": images, "metadata": user_prompt}
        output_dict["type"] = "image"
        return output_dict
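# --- Minimal usage sketch (not part of the original module): assumes a CUDA GPU with
# enough VRAM and access to the default Hugging Face repos above; the prompt string is
# only a placeholder example. ---
if __name__ == "__main__":
    loader = sdxl_lightning_loader()
    result = loader.infer("a lighthouse at dusk, dramatic sky", num_images=2)
    # result["real_output"]["display"] is the list of PIL images returned by the pipeline.
    for idx, img in enumerate(result["real_output"]["display"]):
        img.save(f"lightning_{idx}.png")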