import spaces
import os
import random
import math
import torch
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True
import numpy as np
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
StableDiffusionXLPipeline,
)
from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
EulerAncestralDiscreteScheduler,
)
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
# Optionally load environment variables from a local .env file. python-dotenv
# is not required: when it is missing, the variables are read from the process
# environment as-is.
try:
    from dotenv import load_dotenv
except ImportError:
    # Only a missing package is tolerated; other errors are no longer hidden.
    print("failed to import dotenv (this is not a problem on the production)")
else:
    load_dotenv()
# Runtime configuration read from the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    raise RuntimeError("HF_TOKEN environment variable is required")
IMAGE_MODEL_REPO_ID = os.environ.get(
    "IMAGE_MODEL_REPO_ID", "OnomaAIResearch/Illustrious-xl-early-release-v0"
)
DART_V3_REPO_ID = os.environ.get("DART_V3_REPO_ID")
if DART_V3_REPO_ID is None:
    raise RuntimeError("DART_V3_REPO_ID environment variable is required")
# Any of "1"/"true"/"yes"/"on" (case-insensitive) enables sequential CPU
# offload of the image pipeline; anything else (default "False") disables it.
CPU_OFFLOAD = os.environ.get("CPU_OFFLOAD", "False").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}
# Upper bound for random seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders in the UI.
MAX_IMAGE_SIZE = 2048
# Prompt template fed to the Dart v3 tag-generation model. Adjacent string
# literals are concatenated, so the resulting template is
# "<|bos|><|rating:general|>{aspect_ratio}<|length:medium|>{subject}".
# `{aspect_ratio}` is filled with a token from get_aspect_ratio() and
# `{subject}` with the subject tag selected in the UI.
TEMPLATE = (
    "<|bos|>"
    # Fixed rating and length conditioning around the aspect-ratio token.
    "<|rating:general|>"
    "{aspect_ratio}"
    "<|length:medium|>"
    # NOTE(review): the two empty literals below contribute nothing to the
    # template; they look like placeholders for sections that were removed
    # or intentionally left blank (presumably copyright/character tags) —
    # confirm against the Dart v3 prompt format.
    ""
    # (second empty placeholder section)
    ""
    # Subject tag; generation continues from here.
    "{subject}"
)
# Default prompt suffix appended after the generated tags (editable in the UI).
QUALITY_TAGS = "masterpiece, best quality, very aesthetic"
# Default negative prompt for the image pipeline (editable in the UI).
NEGATIVE_PROMPT = "(worst quality, bad quality:1.1), very displeasing, lowres, jaggy lines, 3d, blurry, watermark, signature, copyright notice, logo, scan, jpeg artifacts, chromatic aberration, white outline, film grain, artistic error, bad anatomy, bad hands, wrong hand, 2010s, 2011s, 2012s, 2013s"
# Tags the Dart model must not emit. Each tag is looked up as a single token
# and passed to `generate` via `bad_words_ids` (assumes every tag is one token
# in the Dart vocabulary — TODO confirm).
BAN_TAGS = [
    "photoshop (medium)",
    "clip studio paint (medium)",
    "absurdres",
    "highres",
    "copyright request",
    "character request",
    "creature request",
]
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dart v3 tag-generation model: loaded on the CPU in bfloat16 and used for
# inference only (eval mode, gradients disabled).
dart = AutoModelForCausalLM.from_pretrained(
    DART_V3_REPO_ID,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
    use_cache=True,
    device_map="cpu",
)
dart = dart.eval()
dart = dart.requires_grad_(False)
# NOTE(review): `dart.generate` is forwarded to the wrapped module by the
# compiled wrapper, so compilation may not affect generation at all — confirm.
dart = torch.compile(dart)
# Pass the token explicitly, like the model load above, so a private
# DART_V3_REPO_ID does not depend on the token being picked up implicitly.
tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID, token=HF_TOKEN)


def _tags_to_bad_words_ids(tags):
    """Map each tag to a one-token id list in the `bad_words_ids` format.

    Tags that are not a token of the Dart vocabulary are skipped with a
    warning: their lookup yields the unknown-token id (or None), and banning
    that would not ban the tag itself (and None would break `generate`).
    """
    bad_words_ids = []
    for tag in tags:
        token_id = tokenizer.convert_tokens_to_ids(tag)
        if token_id is None or token_id == tokenizer.unk_token_id:
            print(f"ban tag is not a Dart token, skipping: {tag!r}")
            continue
        bad_words_ids.append([token_id])
    return bad_words_ids


BAN_TOKENS = _tags_to_bad_words_ids(BAN_TAGS)
def load_pipeline():
    """Build the SDXL text-to-image pipeline.

    Loads IMAGE_MODEL_REPO_ID in float16 together with the
    `madebyollin/sdxl-vae-fp16-fix` VAE and the `lpw_stable_diffusion_xl`
    custom pipeline. It then replaces the scheduler with an Euler Ancestral
    scheduler built from the checkpoint's scheduler config. Finally, it either
    enables sequential CPU offload (when CPU_OFFLOAD is set) or moves the
    whole pipeline to `device`.
    """
    half = torch.float16
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        IMAGE_MODEL_REPO_ID,
        vae=AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=half,
        ),
        torch_dtype=half,
        use_safetensors=True,
        add_watermarker=False,
        custom_pipeline="lpw_stable_diffusion_xl",
    )

    scheduler_config = pipeline.scheduler.config
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)

    if not CPU_OFFLOAD:
        pipeline.to(device)  # for spaces
        return pipeline

    # local setups
    pipeline.enable_sequential_cpu_offload(gpu_id=0, device=device)
    return pipeline
# The image pipeline is only built when a CUDA device is visible at startup;
# otherwise `pipe` remains None.
pipe = None
if torch.cuda.is_available():
    pipe = load_pipeline()
    print("Loaded pipeline")
def get_aspect_ratio(width: int, height: int) -> str:
    """Bucket an image size into one of the aspect-ratio tokens used in TEMPLATE.

    The width/height ratio is classified as follows:

    - ``<= 1/sqrt(3)``: "ultra_tall"
    - ``<= 8/9``: "tall"
    - strictly between 8/9 and 9/8: "square"
    - ``>= 9/8`` and ``< sqrt(3)``: "wide"
    - ``>= sqrt(3)``: "ultra_wide"
    """
    ratio = width / height

    # Portrait side (inclusive bounds).
    if ratio <= 8 / 9:
        if ratio <= 1 / math.sqrt(3):
            return "<|aspect_ratio:ultra_tall|>"
        return "<|aspect_ratio:tall|>"

    # Landscape side (inclusive bounds).
    if ratio >= 9 / 8:
        if ratio >= math.sqrt(3):
            return "<|aspect_ratio:ultra_wide|>"
        return "<|aspect_ratio:wide|>"

    return "<|aspect_ratio:square|>"
@torch.inference_mode
def generate_prompt(subject: str, aspect_ratio: str):
    """Sample a comma-separated tag prompt from the Dart v3 model.

    Args:
        subject: Tag text substituted for ``{subject}`` in TEMPLATE.
        aspect_ratio: Aspect-ratio token substituted for ``{aspect_ratio}``
            (see get_aspect_ratio).

    Returns:
        The sampled sequence minus its first token, decoded token by token
        with special tokens skipped and blank pieces dropped, joined with
        ", ". This includes the subject tag from the template, followed by
        the newly generated tags.
    """
    encoded = tokenizer.encode_plus(
        TEMPLATE.format(aspect_ratio=aspect_ratio, subject=subject),
        return_tensors="pt",
    )
    input_ids = encoded.input_ids
    print("input_ids:", input_ids)
    output_ids = dart.generate(
        input_ids,
        # Explicit mask for this single unpadded sequence, instead of letting
        # generate() infer one (and warn about it).
        attention_mask=encoded.get("attention_mask"),
        max_new_tokens=256,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
        top_k=100,
        num_beams=1,
        bad_words_ids=BAN_TOKENS,
    )[0]
    # Drop only the first token. The prompt part of the sequence is kept on
    # purpose: its special tokens are removed by skip_special_tokens below,
    # while the subject tag survives and leads the final prompt.
    generated = output_ids[1:]
    # batch_decode on a 1-D id tensor decodes each token on its own, so each
    # non-blank piece becomes one comma-separated tag.
    decoded = ", ".join(
        [
            token
            for token in tokenizer.batch_decode(generated, skip_special_tokens=True)
            if token.strip() != ""
        ]
    )
    print("decoded:", decoded)
    return decoded
def format_prompt(prompt: str, prompt_suffix: str):
    """Join the generated prompt and the user's suffix with ", ".

    Both parts are stripped of surrounding whitespace. Blank (or None) parts
    are dropped, so an empty suffix or prompt does not leave a dangling comma
    in the final prompt.
    """
    parts = ((prompt or "").strip(), (prompt_suffix or "").strip())
    return ", ".join(part for part in parts if part)
@spaces.GPU(duration=20)
@torch.inference_mode
def generate_image(
    prompt: str,
    negative_prompt: str,
    generator,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
):
    """Render one image with the module-level SDXL pipeline.

    Runs under ``spaces.GPU(duration=20)`` and inference mode; all arguments
    are passed straight through to the pipeline (``generator`` is the seeded
    torch generator built by the caller).

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: if the pipeline was never loaded because no CUDA device
            was available at startup.
    """
    if pipe is None:
        raise gr.Error(
            "Image pipeline is not loaded (no CUDA device was available at startup)."
        )
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image
def _seeded_generator(seed, randomize_seed):
    """Return ``(seed, generator)``, drawing a fresh seed first if requested.

    The seed is coerced to int because UI slider values may arrive as
    floats, which ``torch.Generator.manual_seed`` does not accept.
    """
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    return seed, torch.Generator().manual_seed(seed)


def on_generate(
    subject: str,
    suffix: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Handle "Pull gacha": sample a new prompt, then render it.

    The seed only drives image sampling; generate_prompt is not given it.

    Returns:
        ``(image, prompt, seed)``: the rendered image, the full prompt (with
        suffix), and the seed actually used.
    """
    seed, generator = _seeded_generator(seed, randomize_seed)
    width, height = int(width), int(height)
    ar = get_aspect_ratio(width, height)
    print("ar:", ar)
    prompt = generate_prompt(subject, ar)
    prompt = format_prompt(prompt, suffix)
    print(prompt)
    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        width,
        height,
        guidance_scale,
        int(num_inference_steps),
    )
    return image, prompt, seed


def on_retry(
    prompt: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Handle "Retry": re-render an existing prompt without sampling a new one.

    Returns:
        ``(image, prompt, seed)``: the rendered image, the unchanged prompt,
        and the seed actually used.
    """
    seed, generator = _seeded_generator(seed, randomize_seed)
    print(prompt)
    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        int(width),
        int(height),
        guidance_scale,
        int(num_inference_steps),
    )
    return image, prompt, seed
# Centers the main column and caps its width; applied to the Column created
# with elem_id="col-container".
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Gradio UI, top to bottom:
# - title
# - subject picker and "Pull gacha" button
# - result image
# - collapsed "Generation details" accordion (generated prompt + retry button)
# - collapsed "Advanced Settings" accordion (prompt and sampling controls)
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# IllustriousXL Random Gacha
Image model: [IllustriousXL v0.1](https://huggingface.co/OnomaAIResearch/Illustrious-xl-early-release-v0)
""")
        with gr.Row():
            # Despite its name this is a dropdown; its value is the subject
            # tag passed to on_generate.
            subject_radio = gr.Dropdown(
                label="Subject",
                choices=["1girl", "2girls", "1boy", "no humans"],
                value="1girl",
            )
            run_button = gr.Button("Pull gacha", variant="primary", scale=0)
        result = gr.Image(label="Gacha result", show_label=False)
        with gr.Accordion("Generation details", open=False):
            with gr.Row():
                # Read-only: shows the last prompt, which Retry reuses as-is.
                prompt_txt = gr.Textbox(label="Generated prompt", interactive=False)
                retry_button = gr.Button("🔄 Retry", scale=0)
        with gr.Accordion("Advanced Settings", open=False):
            # Appended after the generated tags (default: QUALITY_TAGS).
            prompt_suffix = gr.Text(
                label="Prompt suffix",
                visible=True,
                value=QUALITY_TAGS,
            )
            negative_prompt = gr.Text(
                label="Negative prompt",
                placeholder="Enter a negative prompt",
                visible=True,
                value=NEGATIVE_PROMPT,
            )
            # Used only when "Randomize seed" is unchecked; both handlers
            # write the seed actually used back into this slider.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                # Image size; the width/height ratio also selects the Dart
                # aspect-ratio token (see get_aspect_ratio).
                width = gr.Slider(
                    label="Width",
                    minimum=640,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=960,  # default portrait size: 960x1344
                )
                height = gr.Slider(
                    label="Height",
                    minimum=640,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1344,  # default portrait size: 960x1344
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=6.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=20,
                    maximum=40,
                    step=1,
                    value=28,
                )
    # "Pull gacha": sample a new prompt with Dart, then render it. The input
    # order must match on_generate's parameters.
    gr.on(
        triggers=[run_button.click],
        fn=on_generate,
        inputs=[
            subject_radio,
            prompt_suffix,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, prompt_txt, seed],
    )
    # "Retry": re-render the prompt currently shown in prompt_txt. The input
    # order must match on_retry's parameters.
    gr.on(
        triggers=[retry_button.click],
        fn=on_retry,
        inputs=[
            prompt_txt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, prompt_txt, seed],
    )
# Start the server only when this file is run as a script, so importing the
# module (e.g. for tests or tooling) does not launch Gradio.
if __name__ == "__main__":
    demo.queue().launch()