# NOTE(review): `spaces` is imported before torch — presumably required by the
# Hugging Face ZeroGPU runtime; keep it first.
import spaces
import os
import random
import math
import torch
# Favour reproducible cuDNN kernel selection over autotuned (benchmark) kernels.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Allow TensorFloat-32 in CUDA matmuls (faster on supporting GPUs, lower precision).
torch.backends.cuda.matmul.allow_tf32 = True
import numpy as np
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
    StableDiffusionXLPipeline,
)
from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
    EulerAncestralDiscreteScheduler,
)
# NOTE(review): AttnProcessor2_0 is not referenced anywhere else in this file.
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
# Load variables from a local .env file when python-dotenv is installed (local
# development). Only a missing package is tolerated: the production Space gets
# its secrets from the environment directly. Errors raised by load_dotenv()
# itself are no longer hidden behind the "failed to import" message.
try:
    from dotenv import load_dotenv
except ImportError:
    print("failed to import dotenv (this is not a problem on the production)")
else:
    load_dotenv()
# Hugging Face access token; passed to from_pretrained() for the Dart model.
# Validated with an explicit raise (not `assert`) so the check survives `python -O`.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise RuntimeError("The HF_TOKEN environment variable must be set.")

# SDXL checkpoint used for image generation; overridable via the environment.
IMAGE_MODEL_REPO_ID = os.environ.get(
    "IMAGE_MODEL_REPO_ID", "OnomaAIResearch/Illustrious-xl-early-release-v0"
)

# Repo id of the Dart v3 prompt generator; there is no default, so it is required.
DART_V3_REPO_ID = os.environ.get("DART_V3_REPO_ID")
if DART_V3_REPO_ID is None:
    raise RuntimeError("The DART_V3_REPO_ID environment variable must be set.")

# "true" (case-insensitive) enables sequential CPU offload of the image pipeline.
CPU_OFFLOAD = os.environ.get("CPU_OFFLOAD", "False").lower() == "true"

# Upper bound for seeds (largest int32) and for the width/height sliders.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
# Prompt template for the Dart v3 tag generator. `{aspect_ratio}` receives a
# token from get_aspect_ratio() and `{subject}` the subject chosen in the UI;
# the rating and length control tokens are fixed.
# NOTE(review): the empty "" segments and bare `#` lines look like placeholders
# for template sections that were left blank — confirm against the Dart v3
# prompt format before filling them in.
TEMPLATE = (
    "<|bos|>"
    #
    "<|rating:general|>"
    "{aspect_ratio}"
    "<|length:medium|>"
    #
    ""
    #
    ""
    #
    "{subject}"
)
# Default value of the "Prompt suffix" textbox; appended to every generated prompt.
QUALITY_TAGS = ""
# Default value of the "Negative prompt" textbox, forwarded to the image pipeline.
# NOTE(review): "bad anatomy" appears twice; left unchanged because editing the
# text changes generation output.
NEGATIVE_PROMPT = "bad quality, worst quality, lowres, bad anatomy, sketch, jpeg artifacts, ugly, poorly drawn, signature, watermark, bad anatomy, bad hands, bad feet, retro, old, 2000s, 2010s, 2011s, 2012s, 2013s, multiple views, screencap"
# Tags the prompt generator must never emit (converted to bad_words_ids below):
# every year tag from 2005 through 2020 inclusive, then "dated" and "web address".
BAN_TAGS = [str(year) for year in range(2005, 2020 + 1)] + [
    "dated",
    "web address",
]
# Device for the image pipeline (the Dart model below always stays on the CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dart v3 prompt (tag) generator, loaded on the CPU in bfloat16, inference only.
dart = AutoModelForCausalLM.from_pretrained(
    DART_V3_REPO_ID,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
    use_cache=True,
    device_map="cpu",
)
dart = dart.eval()
dart = dart.requires_grad_(False)
# NOTE(review): this model is only used through `dart.generate`, which the
# compiled wrapper forwards to the original module; generation therefore most
# likely does not run the compiled forward — confirm before relying on it.
dart = torch.compile(dart)

# The tokenizer lives in the same (possibly private) repo as the model, so it
# must authenticate with the same token.
tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID, token=HF_TOKEN)

# `bad_words_ids` for generate(): one single-token sequence per banned tag.
# convert_tokens_to_ids() yields None for an unknown tag when the tokenizer has
# no unk token, which generate() cannot accept, so such entries are dropped.
BAN_TOKENS = [
    token_ids
    for token_ids in (tokenizer.convert_tokens_to_ids([tag]) for tag in BAN_TAGS)
    if None not in token_ids
]
def load_pipeline():
    """Build the SDXL image pipeline.

    Loads ``IMAGE_MODEL_REPO_ID`` in float16 with the
    ``madebyollin/sdxl-vae-fp16-fix`` VAE and the ``lpw_stable_diffusion_xl``
    community pipeline, with no watermarker. The scheduler is swapped for Euler
    Ancestral, keeping the checkpoint's scheduler config.

    Placement depends on ``CPU_OFFLOAD``:

    * If set, sequential CPU offload is enabled (local use).
    * Otherwise the whole pipeline is moved to ``device`` (Spaces).

    Returns:
        The configured pipeline object.
    """
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        IMAGE_MODEL_REPO_ID,
        vae=AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        ),
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False,
        custom_pipeline="lpw_stable_diffusion_xl",
    )

    scheduler_config = pipeline.scheduler.config
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)

    if not CPU_OFFLOAD:
        pipeline.to(device)  # for spaces
    else:  # local
        pipeline.enable_sequential_cpu_offload(gpu_id=0, device=device)
    return pipeline
# The image pipeline is only built when CUDA is visible; otherwise `pipe` stays
# None and generate_image() cannot be used.
pipe = None
if torch.cuda.is_available():
    pipe = load_pipeline()
    print("Loaded pipeline")
def get_aspect_ratio(width: int, height: int) -> str:
    """Return the aspect-ratio control token for ``TEMPLATE``.

    Buckets are chosen on ``log2(width / height)``:

    * Portrait side: limits -1.25 / -0.75 / -0.25 are inclusive.
    * Landscape side: limits 0.25 / 0.75 / 1.25 are exclusive.

    Anything strictly between -0.25 and 0.25 is "square".
    """
    log_ratio = math.log2(width / height)

    tall_buckets = (
        (-1.25, "<|aspect_ratio:too_tall|>"),
        (-0.75, "<|aspect_ratio:tall_wallpaper|>"),
        (-0.25, "<|aspect_ratio:tall|>"),
    )
    for upper_limit, token in tall_buckets:
        if log_ratio <= upper_limit:
            return token

    wide_buckets = (
        (0.25, "<|aspect_ratio:square|>"),
        (0.75, "<|aspect_ratio:wide|>"),
        (1.25, "<|aspect_ratio:wide_wallpaper|>"),
    )
    for bound, token in wide_buckets:
        if log_ratio < bound:
            return token

    return "<|aspect_ratio:too_wide|>"
@torch.inference_mode
def generate_prompt(subject: str, aspect_ratio: str):
    """Sample a comma-separated tag prompt from the Dart v3 model.

    Args:
        subject: Subject tag substituted into ``TEMPLATE`` (e.g. ``"1girl"``).
        aspect_ratio: Aspect-ratio control token from ``get_aspect_ratio``.

    Returns:
        The decoded tags joined with ``", "``. The result covers the tags that
        were already in the prompt (such as the subject) followed by the sampled
        continuation. Tags in ``BAN_TAGS`` are blocked through ``bad_words_ids``.
    """
    prompt_text = TEMPLATE.format(aspect_ratio=aspect_ratio, subject=subject)
    input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
    print("input_ids:", input_ids)
    output_ids = dart.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
        top_k=100,
        num_beams=1,
        bad_words_ids=BAN_TOKENS,
    )[0]
    # generate() returns the prompt followed by the continuation. Only the first
    # token (the template starts with <|bos|>) is skipped: slicing off the whole
    # prompt would also drop the subject tag, which the image model needs.
    # Note that len(input_ids) is the batch size (1) for this (1, seq_len)
    # tensor, not the prompt length.
    # Control tokens are removed by skip_special_tokens=True.
    tag_ids = output_ids[1:]
    tags = [
        text
        for text in tokenizer.batch_decode(tag_ids, skip_special_tokens=True)
        if text.strip()
    ]
    decoded = ", ".join(tags)
    print("decoded:", decoded)
    return decoded
def format_prompt(prompt: str, prompt_suffix: str) -> str:
    """Append the user's prompt suffix to a generated prompt.

    Each part is stripped, and empty parts are skipped. An empty suffix (the
    default ``QUALITY_TAGS`` is ``""``) therefore leaves no dangling ``", "``
    at the end of the prompt.

    Args:
        prompt: Generated tag prompt.
        prompt_suffix: Extra tags from the "Prompt suffix" textbox; may be empty.

    Returns:
        The non-empty parts joined with ``", "``.
    """
    parts = [(part or "").strip() for part in (prompt, prompt_suffix)]
    return ", ".join(part for part in parts if part)
@spaces.GPU(duration=20)
@torch.inference_mode
def generate_image(
    prompt: str,
    negative_prompt: str,
    generator,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
):
    """Render one image with the module-level ``pipe``.

    Runs under ``@spaces.GPU(duration=20)`` and with autograd disabled via
    ``torch.inference_mode``. All arguments are forwarded to the pipeline as
    keyword arguments.

    Returns:
        The first image of the pipeline output.
    """
    pipeline_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )
    output = pipe(**pipeline_kwargs)
    return output.images[0]
def on_generate(
    subject: str,
    suffix: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Handle "Pull gacha": sample a new tag prompt, then render it.

    Args:
        subject: Subject tag from the dropdown.
        suffix: Prompt suffix appended to the generated tags.
        negative_prompt: Negative prompt for the image pipeline.
        seed: Seed slider value; replaced by a random one when
            ``randomize_seed`` is true.
        randomize_seed: Whether to draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Output size in pixels.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker (follows tqdm bars).

    Returns:
        ``(image, prompt, seed)``, where ``seed`` is the integer seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # gr.Slider hands values over as floats, but torch.Generator.manual_seed
    # only accepts ints, and the pipeline needs integer sizes/step counts.
    seed = int(seed)
    width, height = int(width), int(height)
    generator = torch.Generator().manual_seed(seed)

    ar = get_aspect_ratio(width, height)
    print("ar:", ar)
    prompt = generate_prompt(subject, ar)
    prompt = format_prompt(prompt, suffix)
    print(prompt)

    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        width,
        height,
        guidance_scale,
        int(num_inference_steps),
    )
    return image, prompt, seed
def on_retry(
    prompt: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Handle "Retry": re-render the displayed prompt without sampling a new one.

    Args:
        prompt: Previously generated prompt (the "Generated prompt" textbox).
        negative_prompt: Negative prompt for the image pipeline.
        seed: Seed slider value; replaced by a random one when
            ``randomize_seed`` is true.
        randomize_seed: Whether to draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Output size in pixels.
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.
        progress: Gradio progress tracker (follows tqdm bars).

    Returns:
        ``(image, prompt, seed)``. The prompt is passed through unchanged, and
        ``seed`` is the integer seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # gr.Slider hands values over as floats, but torch.Generator.manual_seed
    # only accepts ints, and the pipeline needs integer sizes/step counts.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)
    print(prompt)

    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        int(width),
        int(height),
        guidance_scale,
        int(num_inference_steps),
    )
    return image, prompt, seed
# Page CSS: centres the element with id "col-container" and caps its width at 640px.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Gradio UI: one centred column holding the subject picker, the result image,
# the generated-prompt details, and an "Advanced Settings" accordion.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # NOTE(review): the model name/link is hard-coded even though
        # IMAGE_MODEL_REPO_ID can be overridden via the environment.
        gr.Markdown("""
# IllustriousXL Random Gacha
Image model: [IllustriousXL v0.1](https://huggingface.co/OnomaAIResearch/Illustrious-xl-early-release-v0)
""")
        with gr.Row():
            # Selected value becomes `subject` in on_generate (-> TEMPLATE's {subject}).
            subject_radio = gr.Dropdown(
                label="Subject",
                choices=["1girl", "2girls", "1boy", "no humans"],
                value="1girl",
            )
            run_button = gr.Button("Pull gacha", variant="primary", scale=0)
        result = gr.Image(label="Gacha result", show_label=False)
        with gr.Accordion("Generation details", open=False):
            with gr.Row():
                # Read-only: written by both handlers and read back by on_retry.
                prompt_txt = gr.Textbox(label="Generated prompt", interactive=False)
                retry_button = gr.Button("🔄 Retry", scale=0)
        with gr.Accordion("Advanced Settings", open=False):
            prompt_suffix = gr.Text(
                label="Prompt suffix",
                visible=True,
                value=QUALITY_TAGS,
            )
            negative_prompt = gr.Text(
                label="Negative prompt",
                placeholder="Enter a negative prompt",
                visible=True,
                value=NEGATIVE_PROMPT,
            )
            # Also an output of both handlers, so it shows the seed actually used.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=640,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=960,  # Replace with defaults that work for your model
                )
                height = gr.Slider(
                    label="Height",
                    minimum=640,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1344,  # Replace with defaults that work for your model
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=6.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=20,
                    maximum=40,
                    step=1,
                    value=28,
                )
    # "Pull gacha": sample a new prompt from the subject, then render it.
    gr.on(
        triggers=[run_button.click],
        fn=on_generate,
        inputs=[
            subject_radio,
            prompt_suffix,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, prompt_txt, seed],
    )
    # "Retry": re-render the prompt currently shown in prompt_txt with the
    # current settings (and a new seed if randomize_seed is checked).
    gr.on(
        triggers=[retry_button.click],
        fn=on_retry,
        inputs=[
            prompt_txt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, prompt_txt, seed],
    )
# Enable Gradio's request queue and start the server.
demo.queue().launch()