# newer_project/modules/lightning.py
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, DPMSolverSinglestepScheduler, AutoencoderTiny, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import random


class sdxl_lightning_loader:
def __init__(self, ckpt="sdxl_lightning_4step_unet.safetensors", base="stabilityai/stable-diffusion-xl-base-1.0", repo="ByteDance/SDXL-Lightning"):
        # Loading this tiny VAE is important: it saves roughly 4 GB of VRAM and
        # gives a slight speed increase with almost no loss in quality.
self.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', use_safetensors=True, torch_dtype=torch.float16).to('cuda:0')
self.unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda:0", torch.float16)
self.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda:0"))
self.pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=self.unet, torch_dtype=torch.float16, variant="fp16", vae=self.vae).to("cuda:0")
        # SDXL-Lightning checkpoints expect "trailing" timestep spacing on the sampler.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config, timestep_spacing="trailing")
        # Slicing and tiling keep VAE memory usage flat at high resolutions;
        # fusing the QKV projections trims a little attention overhead.
        self.pipe.enable_vae_slicing()
self.pipe.enable_vae_tiling()
self.pipe.fuse_qkv_projections()

    def infer(self, user_prompt, num_images=None, save=True, height=1024, width=1024, guidance_scale=0.75, neg_prompt=""):
if not num_images:
num_images = 1
prompt = [f"{user_prompt}((8k, RAW photo, highest quality, masterpiece), High detail RAW color photo professional photo, (highest quality), (best shadow), (best illustration), ultra high resolution, highly detailed CG unified 8K wallpapers, physics-based rendering, cinematic lighting)"]
neg_prompt = [f"{neg_prompt}, (worst quality, low quality, normal quality, lowres, low details, oversaturated, undersaturated, overexposed, underexposed, grayscale, bw, bad photo, bad photography, bad art:1.4), (watermark, signature, text font, username, error, logo, words, letters, digits, autograph, trademark, name:1.2), (blur, blurry, grainy), morbid, ugly, asymmetrical, mutated malformed, mutilated, poorly lit, bad shadow, draft, cropped, out of frame, cut off, censored, jpeg artifacts, out of focus, glitch, duplicate, (airbrushed, cartoon, anime, semi-realistic, cgi, render, blender, digital art, manga, amateur:1.3), (3D ,3D Game, 3D Game Scene, 3D Character:1.1), (bad hands, bad anatomy, bad body, bad face, bad teeth, bad arms, bad legs, deformities:1.3)"]
samples = self.pipe(prompt=prompt, negative_prompt=neg_prompt, height=height, width=width, num_inference_steps=4, guidance_scale=guidance_scale, num_images_per_prompt=num_images).images
random_num = random.randint(1, 100)
output_dict = {}
output_dict["llm_output"] = f"Done. Images's have been generated and displayed. Stored as list variable {random_num}"
output_dict["real_output"] = {"display": samples, "metadata": user_prompt}
output_dict["type"] = "generated_image"
output_dict['name'] = str(random_num)
return output_dict
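
# A minimal usage sketch (hypothetical; not exercised by this module). Assumes
# a CUDA GPU with enough free VRAM for SDXL in fp16:
#
#   loader = sdxl_lightning_loader()
#   result = loader.infer("a lighthouse at dusk")
#   result["real_output"]["display"][0].save("lighthouse.png")
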
class sdxl_lightning:
def __init__(self, sdxl_path):
self.pipe = StableDiffusionXLPipeline.from_pretrained(sdxl_path, torch_dtype=torch.float16, variant='fp16').to("cuda:0")
        # DPM-Solver (single-step) with Karras sigmas, configured for few-step sampling.
        self.pipe.scheduler = DPMSolverSinglestepScheduler.from_config(self.pipe.scheduler.config, use_karras_sigmas=True, timestep_spacing="trailing")
#self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config, timestep_spacing="trailing")
self.pipe.enable_vae_slicing()
self.pipe.enable_vae_tiling()
self.guidance_scale = 1.5
self.neg_prompt = ["(worst quality, low quality, normal quality, lowres, low details, oversaturated, undersaturated, overexposed, underexposed, grayscale, bw, bad photo, bad photography, bad art:1.4), (watermark, signature, text font, username, error, logo, words, letters, digits, autograph, trademark, name:1.2), (blur, blurry, grainy), morbid, ugly, asymmetrical, mutated malformed, mutilated, poorly lit, bad shadow, draft, cropped, out of frame, cut off, censored, jpeg artifacts, out of focus, glitch, duplicate, (airbrushed, cartoon, anime, semi-realistic, cgi, render, blender, digital art, manga, amateur:1.3), (3D ,3D Game, 3D Game Scene, 3D Character:1.1), (bad hands, bad anatomy, bad body, bad face, bad teeth, bad arms, bad legs, deformities:1.3)"]

    def infer(self, user_prompt, num_images=None, steps=6, type="portrait", save=True, height=1024, width=1024, neg_prompt="bad quality, deformed", change_neg=False):
if not num_images:
num_images = 1
        if type == "portrait":
            # Note: "portrait" currently yields a square 1024x1024 canvas.
            width, height = 1024, 1024
        elif type == 'landscape':
            width, height = 1344, 768
images = []
neg_prompt = str(neg_prompt)
user_prompt = str(user_prompt)
if isinstance(num_images, str):
num_images = num_images.replace('"', '').replace("'", '')
num_images = int(num_images)
prompt = f"masterpiece, high quality, 8k, RAW photo, {user_prompt}, perfect color, photograph, cinematic, perfect photograph, incredible photograph, 16k"
        # Swap in the caller-supplied negative prompt only when change_neg is
        # set; otherwise use the class-level default.
        negative_prompt = [neg_prompt] if change_neg else self.neg_prompt
        samples = self.pipe(prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, num_inference_steps=steps, guidance_scale=self.guidance_scale, num_images_per_prompt=num_images).images
for i in range(num_images):
image = f"image{i}.jpeg"
samples[i].save(image)
images.append(image)
output_dict = {}
output_dict["llm_output"] = f"Done. Images's have been generated and displayed."
output_dict["display"] = {'images': images, 'metadata': user_prompt}
output_dict["type"] = "image"
return output_dict
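

# Minimal smoke test: a sketch assuming the checkpoint path points at an fp16
# SDXL checkpoint directory. The base model below is only a placeholder; the
# 6-step / guidance-1.5 defaults are meant for a Lightning-merged checkpoint.
if __name__ == "__main__":
    generator = sdxl_lightning("stabilityai/stable-diffusion-xl-base-1.0")
    result = generator.infer("a lighthouse at dusk", num_images=1)
    print(result["llm_output"], result["display"]["images"])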