# diffusion/lib/inference.py
import os
import re
import time
from datetime import datetime
from itertools import product

import gradio as gr
import numpy as np
import spaces
import torch
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
from PIL import Image

from .config import Config
from .loader import Loader
from .logger import Logger
from .utils import load_json
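
# Expands "[[a, b]]" arrays in a prompt into every combination of values,
# e.g. "a [[red, blue]] ball" -> ["a red ball", "a blue ball"].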
def parse_prompt_with_arrays(prompt: str) -> list[str]:
    arrays = re.findall(r"\[\[(.*?)\]\]", prompt)
    if not arrays:
        return [prompt]

    tokens = [item.split(",") for item in arrays]  # [["a", "b"], ["1", "2"]]
    combinations = list(product(*tokens))  # [("a", "1"), ("a", "2"), ("b", "1"), ("b", "2")]

    # find all the arrays in the prompt and replace them with tokens
    prompts = []
    for combo in combinations:
        current_prompt = prompt
        for i, token in enumerate(combo):
            current_prompt = current_prompt.replace(f"[[{arrays[i]}]]", token.strip(), 1)
        prompts.append(current_prompt)
    return prompts
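
# Applies a named style template from data/styles.json to the prompts.
# Templates are format strings with {prompt} and {_base} placeholders;
# "none" or an unknown style id returns the prompts unchanged.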
def apply_style(positive_prompt, negative_prompt, style_id="none"):
    if style_id.lower() == "none":
        return (positive_prompt, negative_prompt)

    styles = load_json("./data/styles.json")
    style = styles.get(style_id)
    if style is None:
        return (positive_prompt, negative_prompt)

    style_base = styles.get("_base", {})
    return (
        style.get("positive")
        .format(prompt=positive_prompt, _base=style_base.get("positive"))
        .strip(),
        style.get("negative")
        .format(prompt=negative_prompt, _base=style_base.get("negative"))
        .strip(),
    )
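
# Normalizes an image prompt (PIL image, numpy array, or file path) to an RGB
# PIL image, optionally LANCZOS-resized to `size`; anything else is rejected.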
def prepare_image(input, size=None):
    image = None
    if isinstance(input, Image.Image):
        image = input
    if isinstance(input, np.ndarray):
        image = Image.fromarray(input)
    if isinstance(input, str):
        if os.path.isfile(input):
            image = Image.open(input)
    if image is not None:
        image = image.convert("RGB")
        if size is not None:
            image = image.resize(size, Image.Resampling.LANCZOS)
    if image is not None:
        return image
    else:
        raise ValueError("Invalid image prompt")
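
# Estimates the ZeroGPU allocation (seconds) for a generate() call: a fixed
# loading budget plus a per-image cost that grows for large canvases and 4x
# upscaling. Example: 1024x1024, scale=4, num_images=2 -> 20 + (20 * 2) = 60s.
# @spaces.GPU supports a callable duration, evaluated per call with the
# wrapped function's arguments.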
def gpu_duration(**kwargs):
    loading = 20
    duration = 10
    width = kwargs.get("width", 512)
    height = kwargs.get("height", 512)
    scale = kwargs.get("scale", 1)
    num_images = kwargs.get("num_images", 1)
    size = width * height
    if size > 500_000:
        duration += 5
    if scale == 4:
        duration += 5
    return loading + (duration * num_images)
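
# Main txt2img/img2img entry point. Info, Error, and Progress are injected by
# the caller (presumably the Gradio app) so messages and progress surface in the UI.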
@spaces.GPU(duration=gpu_duration)
def generate(
    positive_prompt,
    negative_prompt="",
    image_prompt=None,
    ip_image_prompt=None,
    ip_face=False,
    lora_1=None,
    lora_1_weight=0.0,
    lora_2=None,
    lora_2_weight=0.0,
    embeddings=[],
    style=None,
    seed=None,
    model="Lykon/dreamshaper-8",
    scheduler="DDIM",
    width=512,
    height=512,
    guidance_scale=7.5,
    inference_steps=40,
    denoising_strength=0.8,
    deepcache=1,
    scale=1,
    num_images=1,
    karras=False,
    taesd=False,
    freeu=False,
    clip_skip=False,
    Info=None,
    Error=Exception,
    Progress=None,
    progress=gr.Progress(track_tqdm=True),
):
    if not torch.cuda.is_available():
        raise Error("CUDA not available")

    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
    if seed is None or seed < 0:
        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)

    CURRENT_STEP = 0
    CURRENT_IMAGE = 1
    KIND = "img2img" if image_prompt is not None else "txt2img"
    EMBEDDINGS_TYPE = (
        ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
        if clip_skip
        else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
    )

    if ip_image_prompt:
        IP_ADAPTER = "full-face" if ip_face else "plus"
    else:
        IP_ADAPTER = ""

    if Progress is not None:
        TQDM = False
        progress_bar = Progress()
        progress_bar((0, inference_steps), desc=f"Generating image {CURRENT_IMAGE}/{num_images}")
    else:
        TQDM = True
        progress_bar = None
    def callback_on_step_end(pipeline, step, timestep, latents):
        nonlocal CURRENT_STEP, CURRENT_IMAGE
        if Progress is None:
            return latents
        strength = denoising_strength if KIND == "img2img" else 1
        total_steps = min(int(inference_steps * strength), inference_steps)
        CURRENT_STEP = step + 1
        progress_bar(
            (CURRENT_STEP, total_steps),
            desc=f"Generating image {CURRENT_IMAGE}/{num_images}",
        )
        return latents
    start = time.perf_counter()
    log = Logger("generate")
    log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}")

    loader = Loader()
    loader.load(
        KIND,
        IP_ADAPTER,
        model,
        scheduler,
        karras,
        taesd,
        freeu,
        deepcache,
        scale,
        TQDM,
    )

    if loader.pipe is None:
        raise Error(f"Error loading {model}")

    pipe = loader.pipe
    upscaler = None
    if scale == 2:
        upscaler = loader.upscaler_2x
    if scale == 4:
        upscaler = loader.upscaler_4x
    # load loras
    loras = []
    weights = []
    loras_and_weights = [(lora_1, lora_1_weight), (lora_2, lora_2_weight)]
    loras_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "loras"))
    for lora, weight in loras_and_weights:
        if lora and lora.lower() != "none" and lora not in loras:
            config = Config.CIVIT_LORAS.get(lora)
            if config:
                try:
                    pipe.load_lora_weights(
                        loras_dir,
                        adapter_name=lora,
                        weight_name=f"{lora}.{config['model_version_id']}.safetensors",
                    )
                    weights.append(weight)
                    loras.append(lora)
                except Exception:
                    raise Error(f"Error loading {config['name']} LoRA")

    # unload after generating or if there was an error
    try:
        if loras:
            pipe.set_adapters(loras, adapter_weights=weights)
    except Exception:
        pipe.unload_lora_weights()
        raise Error("Error setting LoRA weights")
    # load embeddings
    embeddings_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "embeddings"))
    for embedding in embeddings:
        try:
            # wrap embeddings in angle brackets
            pipe.load_textual_inversion(
                pretrained_model_name_or_path=f"{embeddings_dir}/{embedding}.pt",
                token=f"<{embedding}>",
            )
        except (EnvironmentError, HFValidationError, RepositoryNotFoundError):
            raise Error(f"Invalid embedding: {embedding}")

    # prompt embeds
    compel = Compel(
        device=pipe.device,
        tokenizer=pipe.tokenizer,
        truncate_long_prompts=False,
        text_encoder=pipe.text_encoder,
        returned_embeddings_type=EMBEDDINGS_TYPE,
        dtype_for_device_getter=lambda _: pipe.dtype,
        textual_inversion_manager=DiffusersTextualInversionManager(pipe),
    )
    images = []
    current_seed = seed
    for i in range(num_images):
        try:
            generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
            positive_prompts = parse_prompt_with_arrays(positive_prompt)
            index = i % len(positive_prompts)
            positive_styled, negative_styled = apply_style(
                positive_prompts[index],
                negative_prompt,
                style,
            )

            if negative_styled.startswith("(), "):
                negative_styled = negative_styled[4:]

            for lora in loras:
                positive_styled += f", {Config.CIVIT_LORAS[lora]['trigger']}"
            for embedding in embeddings:
                negative_styled += f", <{embedding}>"

            positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length(
                [compel(positive_styled), compel(negative_styled)]
            )
        except PromptParser.ParsingException:
            raise Error("Invalid prompt")
        kwargs = {
            "width": width,
            "height": height,
            "generator": generator,
            "prompt_embeds": positive_embeds,
            "guidance_scale": guidance_scale,
            "num_inference_steps": inference_steps,
            "negative_prompt_embeds": negative_embeds,
            "output_type": "np" if scale > 1 else "pil",
        }

        if progress is not None:
            kwargs["callback_on_step_end"] = callback_on_step_end

        if KIND == "img2img":
            kwargs["strength"] = denoising_strength
            kwargs["image"] = prepare_image(image_prompt, (width, height))

        if IP_ADAPTER:
            # don't resize full-face images since they are usually square crops
            size = None if ip_face else (width, height)
            kwargs["ip_adapter_image"] = prepare_image(ip_image_prompt, size)
        try:
            image = pipe(**kwargs).images[0]
            if scale > 1:
                image = upscaler.predict(image)
            images.append((image, str(current_seed)))
            current_seed += 1
        except Exception as e:
            raise Error(f"{e}")
        finally:
            if embeddings:
                pipe.unload_textual_inversion()
            if loras:
                pipe.unload_lora_weights()
            CURRENT_STEP = 0
            CURRENT_IMAGE += 1

    diff = time.perf_counter() - start
    msg = f"Generating {len(images)} image{'s' if len(images) > 1 else ''} done in {diff:.2f}s"
    log.info(msg)
    if Info:
        Info(msg)
    return images
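
# A minimal usage sketch (hypothetical; in practice the Gradio app calls this):
#   results = generate("a [[red, blue]] ball", num_images=2, seed=42)
#   for img, used_seed in results:  # (PIL image, seed as string) pairs
#       img.save(f"{used_seed}.png")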