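"""IllustriousXL Random Gacha.

A Gradio app (Hugging Face Space) that samples a random tag prompt with a Dart
tag-generation model and renders it with the Illustrious-XL SDXL checkpoint.
"""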
import spaces
import os
import random
import math

import torch

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True

import numpy as np
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
    StableDiffusionXLPipeline,
)
from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
    EulerAncestralDiscreteScheduler,
)
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    print("failed to import dotenv (this is not a problem in production)")
HF_TOKEN = os.environ.get("HF_TOKEN")
assert HF_TOKEN is not None

IMAGE_MODEL_REPO_ID = os.environ.get(
    "IMAGE_MODEL_REPO_ID", "OnomaAIResearch/Illustrious-xl-early-release-v0"
)
DART_V3_REPO_ID = os.environ.get("DART_V3_REPO_ID", None)
assert DART_V3_REPO_ID is not None

CPU_OFFLOAD = os.environ.get("CPU_OFFLOAD", "False").lower() == "true"

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
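# Prompt template for the Dart tag generator: rating / aspect-ratio / length
# control tokens, empty <copyright> and <character> blocks, and a <general>
# block primed with the chosen subject tag for the model to continue.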
TEMPLATE = (
    "<|bos|>"
    #
    "<|rating:general|>"
    "{aspect_ratio}"
    "<|length:medium|>"
    #
    "<copyright></copyright>"
    #
    "<character></character>"
    #
    "<general>{subject}"
)

QUALITY_TAGS = ""
NEGATIVE_PROMPT = "bad quality, worst quality, lowres, bad anatomy, sketch, jpeg artifacts, ugly, poorly drawn, signature, watermark, bad anatomy, bad hands, bad feet, retro, old, 2000s, 2010s, 2011s, 2012s, 2013s, multiple views, screencap"

BAN_TAGS = [
    "2005",  # year tags
    "2006",
    "2007",
    "2008",
    "2009",
    "2010",
    "2011",
    "2012",
    "2013",
    "2014",
    "2015",
    "2016",
    "2017",
    "2018",
    "2019",
    "2020",
    "dated",
    "web address",
]
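# Load the Dart tag generator on CPU in bfloat16, freeze it, and compile it with
# torch.compile (presumably to leave GPU memory for the SDXL pipeline).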
device = "cuda" if torch.cuda.is_available() else "cpu"

dart = AutoModelForCausalLM.from_pretrained(
    DART_V3_REPO_ID,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
    use_cache=True,
    device_map="cpu",
)
dart = dart.eval()
dart = dart.requires_grad_(False)
dart = torch.compile(dart)

tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID)
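# generate()'s bad_words_ids expects a list of token-id lists;
# convert_tokens_to_ids([tag]) yields a one-element id list per banned tag.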
BAN_TOKENS = [tokenizer.convert_tokens_to_ids([tag]) for tag in BAN_TAGS]
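# Build the SDXL pipeline with the fp16-safe VAE and the community
# "lpw_stable_diffusion_xl" pipeline, which supports long weighted prompts.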
def load_pipeline():
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    )
    pipe = StableDiffusionXLPipeline.from_pretrained(
        IMAGE_MODEL_REPO_ID,
        vae=vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False,
        custom_pipeline="lpw_stable_diffusion_xl",
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    if CPU_OFFLOAD:  # local
        pipe.enable_sequential_cpu_offload(gpu_id=0, device=device)
    else:
        pipe.to(device)  # for spaces

    return pipe


if torch.cuda.is_available():
    pipe = load_pipeline()
    print("Loaded pipeline")
else:
    pipe = None
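# Map width/height to one of the aspect-ratio control tokens by bucketing
# log2(width / height): 0 is square, negative is taller, positive is wider.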
def get_aspect_ratio(width: int, height: int) -> str:
    ar = math.log2(width / height)

    if ar <= -1.25:
        return "<|aspect_ratio:too_tall|>"
    elif ar <= -0.75:
        return "<|aspect_ratio:tall_wallpaper|>"
    elif ar <= -0.25:
        return "<|aspect_ratio:tall|>"
    elif ar < 0.25:
        return "<|aspect_ratio:square|>"
    elif ar < 0.75:
        return "<|aspect_ratio:wide|>"
    elif ar < 1.25:
        return "<|aspect_ratio:wide_wallpaper|>"
    else:
        return "<|aspect_ratio:too_wide|>"
def generate_prompt(subject: str, aspect_ratio: str):
    input_ids = tokenizer.encode_plus(
        TEMPLATE.format(aspect_ratio=aspect_ratio, subject=subject),
        return_tensors="pt",
    ).input_ids
    print("input_ids:", input_ids)

    output_ids = dart.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
        top_k=100,
        num_beams=1,
        bad_words_ids=BAN_TOKENS,
    )[0]
    # input_ids has shape (1, seq_len), so len(input_ids) is 1: this drops only
    # the first token and keeps the subject tag in the decoded output (the
    # template's control tokens are expected to be filtered out by
    # skip_special_tokens below).
    generated = output_ids[len(input_ids) :]

    decoded = ", ".join(
        [
            token
            for token in tokenizer.batch_decode(generated, skip_special_tokens=True)
            if token.strip() != ""
        ]
    )
    print("decoded:", decoded)

    return decoded
def format_prompt(prompt: str, prompt_suffix: str):
    return f"{prompt}, {prompt_suffix}"
def generate_image(
    prompt: str,
    negative_prompt: str,
    generator,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
):
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    return image
# `spaces` is imported above and the Space runs on ZeroGPU, so the Gradio
# handlers that reach the CUDA pipeline are decorated with @spaces.GPU
# (assumed placement, matching the standard ZeroGPU text-to-image template).
@spaces.GPU
def on_generate(
    subject: str,
    suffix: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    ar = get_aspect_ratio(width, height)
    print("ar:", ar)

    prompt = generate_prompt(subject, ar)
    prompt = format_prompt(prompt, suffix)
    print(prompt)

    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    )

    return image, prompt, seed
# Retry re-renders the stored prompt with a (possibly new) seed; same assumed
# ZeroGPU decorator as above.
@spaces.GPU
def on_retry(
    prompt: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    print(prompt)
    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    )

    return image, prompt, seed
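# Gradio UI: a subject dropdown with a "Pull gacha" button, the resulting
# image, the generated prompt with a retry button, and advanced settings
# (prompt suffix, negative prompt, seed, size, guidance, steps).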
css = """ | |
#col-container { | |
margin: 0 auto; | |
max-width: 640px; | |
} | |
""" | |
with gr.Blocks(css=css) as demo: | |
with gr.Column(elem_id="col-container"): | |
gr.Markdown(""" | |
# IllustriousXL Random Gacha | |
Image model: [IllustriousXL v0.1](https://huggingface.co/OnomaAIResearch/Illustrious-xl-early-release-v0) | |
""") | |
with gr.Row(): | |
subject_radio = gr.Dropdown( | |
label="Subject", | |
choices=["1girl", "2girls", "1boy", "no humans"], | |
value="1girl", | |
) | |
run_button = gr.Button("Pull gacha", variant="primary", scale=0) | |
result = gr.Image(label="Gacha result", show_label=False) | |
with gr.Accordion("Generation details", open=False): | |
with gr.Row(): | |
prompt_txt = gr.Textbox(label="Generated prompt", interactive=False) | |
                retry_button = gr.Button("Retry", scale=0)
with gr.Accordion("Advanced Settings", open=False): | |
prompt_suffix = gr.Text( | |
label="Prompt suffix", | |
visible=True, | |
value=QUALITY_TAGS, | |
) | |
negative_prompt = gr.Text( | |
label="Negative prompt", | |
placeholder="Enter a negative prompt", | |
visible=True, | |
value=NEGATIVE_PROMPT, | |
) | |
seed = gr.Slider( | |
label="Seed", | |
minimum=0, | |
maximum=MAX_SEED, | |
step=1, | |
value=0, | |
) | |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True) | |
with gr.Row(): | |
width = gr.Slider( | |
label="Width", | |
minimum=640, | |
maximum=MAX_IMAGE_SIZE, | |
step=64, | |
value=960, # Replace with defaults that work for your model | |
) | |
height = gr.Slider( | |
label="Height", | |
minimum=640, | |
maximum=MAX_IMAGE_SIZE, | |
step=64, | |
value=1344, # Replace with defaults that work for your model | |
) | |
with gr.Row(): | |
guidance_scale = gr.Slider( | |
label="Guidance scale", | |
minimum=1.0, | |
maximum=10.0, | |
step=0.5, | |
value=6.5, | |
) | |
num_inference_steps = gr.Slider( | |
label="Number of inference steps", | |
minimum=20, | |
maximum=40, | |
step=1, | |
value=28, | |
) | |
gr.on( | |
triggers=[run_button.click], | |
fn=on_generate, | |
inputs=[ | |
subject_radio, | |
prompt_suffix, | |
negative_prompt, | |
seed, | |
randomize_seed, | |
width, | |
height, | |
guidance_scale, | |
num_inference_steps, | |
], | |
outputs=[result, prompt_txt, seed], | |
) | |
gr.on( | |
triggers=[retry_button.click], | |
fn=on_retry, | |
inputs=[ | |
prompt_txt, | |
negative_prompt, | |
seed, | |
randomize_seed, | |
width, | |
height, | |
guidance_scale, | |
num_inference_steps, | |
], | |
outputs=[result, prompt_txt, seed], | |
) | |
demo.queue().launch() | |