"""IllustriousXL "random gacha" Gradio app.

A DART causal language model (``DART_V3_REPO_ID``) samples a comma-separated tag
prompt conditioned on the requested aspect ratio, and that prompt is rendered
with a Stable Diffusion XL pipeline (``IMAGE_MODEL_REPO_ID``).

Environment variables:
    HF_TOKEN: Hugging Face token passed when downloading the DART model (required).
    DART_V3_REPO_ID: repository id of the DART prompt model (required).
    IMAGE_MODEL_REPO_ID: SDXL checkpoint repository id
        (default: OnomaAIResearch/Illustrious-xl-early-release-v0).
    CPU_OFFLOAD: "true" (case-insensitive) enables sequential CPU offload.
"""

# NOTE: `spaces` is kept as the very first import, as in the original file.
import spaces

import math
import os
import random

import numpy as np
import torch
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
    StableDiffusionXLPipeline,
)
from diffusers.schedulers.scheduling_euler_ancestral_discrete import (
    EulerAncestralDiscreteScheduler,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# Deterministic, non-autotuned cuDNN kernels (presumably so a fixed seed gives a
# reproducible image); TF32 matmuls are allowed for speed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True

try:
    from dotenv import load_dotenv
except ImportError:
    # python-dotenv is optional: production supplies real environment variables.
    print("failed to import dotenv (this is not a problem on the production)")
else:
    load_dotenv()


def _require_env(name: str) -> str:
    """Return environment variable *name*, raising if it is not set.

    An explicit check is used instead of ``assert`` because asserts are removed
    when Python runs with ``-O``.

    Raises:
        RuntimeError: if the variable is unset.
    """
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError(f"environment variable {name} must be set")
    return value


HF_TOKEN = _require_env("HF_TOKEN")
IMAGE_MODEL_REPO_ID = os.environ.get(
    "IMAGE_MODEL_REPO_ID", "OnomaAIResearch/Illustrious-xl-early-release-v0"
)
DART_V3_REPO_ID = _require_env("DART_V3_REPO_ID")
CPU_OFFLOAD = os.environ.get("CPU_OFFLOAD", "False").lower() == "true"

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Prompt prefix fed to the DART model. Only the BOS, aspect-ratio and length
# tokens are active; the commented-out entries are not part of the string.
# NOTE(review): "{subject}" is commented out, so the Subject dropdown value passed
# to generate_prompt() never reaches the model. Confirm whether that is intended.
TEMPLATE = (
    "<|bos|>"
    # "<|rating:general|>"
    "{aspect_ratio}"
    "<|length:medium|>"
    # "{subject}"
)
QUALITY_TAGS = ""
NEGATIVE_PROMPT = "bad quality, worst quality, lowres, bad anatomy, sketch, jpeg artifacts, ugly, poorly drawn, signature, watermark, bad anatomy, bad hands, bad feet, retro, old, 2000s, 2010s, 2011s, 2012s, 2013s, multiple views, screencap"

# Tags the DART model must never emit: the year tags 2005-2020, plus two others.
BAN_TAGS = [str(year) for year in range(2005, 2021)] + ["dated", "web address"]

device = "cuda" if torch.cuda.is_available() else "cpu"

# The DART prompt model runs on the CPU in bfloat16, frozen and compiled.
dart = AutoModelForCausalLM.from_pretrained(
    DART_V3_REPO_ID,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
    use_cache=True,
    device_map="cpu",
)
dart = dart.eval()
dart = dart.requires_grad_(False)
dart = torch.compile(dart)
tokenizer = AutoTokenizer.from_pretrained(DART_V3_REPO_ID, token=HF_TOKEN)

# One token-id sequence per banned tag, in the list-of-lists form that
# `generate(bad_words_ids=...)` takes.
BAN_TOKENS = [tokenizer.convert_tokens_to_ids([tag]) for tag in BAN_TAGS]


def load_pipeline():
    """Build the SDXL pipeline used to render images.

    Uses the fp16-fixed SDXL VAE, the long-prompt-weighting community pipeline
    and an Euler-ancestral scheduler. With ``CPU_OFFLOAD`` the pipeline uses
    sequential CPU offload; otherwise it is moved to ``device``.
    """
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix",
        torch_dtype=torch.float16,
    )
    pipe = StableDiffusionXLPipeline.from_pretrained(
        IMAGE_MODEL_REPO_ID,
        vae=vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        add_watermarker=False,
        custom_pipeline="lpw_stable_diffusion_xl",
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

    if CPU_OFFLOAD:
        # local
        pipe.enable_sequential_cpu_offload(gpu_id=0, device=device)
    else:
        # for spaces
        pipe.to(device)

    return pipe


if torch.cuda.is_available():
    pipe = load_pipeline()
    print("Loaded pipeline")
else:
    pipe = None


def get_aspect_ratio(width: int, height: int) -> str:
    """Map an image size to the DART aspect-ratio token.

    Buckets are chosen on ``log2(width / height)`` in 0.5-wide bands centred
    on 0 (square). Negative values are portrait; positive values are landscape.
    """
    ar = math.log2(width / height)

    if ar <= -1.25:
        return "<|aspect_ratio:too_tall|>"
    elif ar <= -0.75:
        return "<|aspect_ratio:tall_wallpaper|>"
    elif ar <= -0.25:
        return "<|aspect_ratio:tall|>"
    elif ar < 0.25:
        return "<|aspect_ratio:square|>"
    elif ar < 0.75:
        return "<|aspect_ratio:wide|>"
    elif ar < 1.25:
        return "<|aspect_ratio:wide_wallpaper|>"
    else:
        return "<|aspect_ratio:too_wide|>"


@torch.inference_mode
def generate_prompt(subject: str, aspect_ratio: str):
    """Sample a tag prompt from the DART model.

    Args:
        subject: subject tag; only used if TEMPLATE contains ``{subject}``.
        aspect_ratio: aspect-ratio token from ``get_aspect_ratio``.

    Returns:
        The generated tokens decoded one-by-one, empty ones dropped, joined
        with ", ".
    """
    input_ids = tokenizer.encode_plus(
        TEMPLATE.format(aspect_ratio=aspect_ratio, subject=subject),
        return_tensors="pt",
    ).input_ids
    print("input_ids:", input_ids)

    output_ids = dart.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=1.0,
        top_p=1.0,
        top_k=100,
        num_beams=1,
        bad_words_ids=BAN_TOKENS,
    )[0]

    # `input_ids` is batched with shape (1, prompt_len), so `len(input_ids)` is 1.
    # Slice by the sequence length so the whole prompt is dropped, not only its
    # first token.
    generated = output_ids[input_ids.shape[-1] :]
    decoded = ", ".join(
        [
            token
            for token in tokenizer.batch_decode(generated, skip_special_tokens=True)
            if token.strip() != ""
        ]
    )
    print("decoded:", decoded)

    return decoded


def format_prompt(prompt: str, prompt_suffix: str):
    """Append *prompt_suffix* to *prompt* with ", ".

    Empty or whitespace-only parts are skipped. This avoids a dangling ", "
    when the suffix is empty, which it is by default (``QUALITY_TAGS``).
    """
    parts = [part.strip() for part in (prompt, prompt_suffix) if part and part.strip()]
    return ", ".join(parts)


@spaces.GPU(duration=20)
@torch.inference_mode
def generate_image(
    prompt: str,
    negative_prompt: str,
    generator,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
):
    """Render one image with the SDXL pipeline and return it.

    Raises:
        gr.Error: if the pipeline was not loaded (no CUDA device at startup).
    """
    if pipe is None:
        raise gr.Error("The image pipeline is not loaded (CUDA is unavailable).")

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image


def _make_generator(seed, randomize_seed):
    """Return ``(seed, generator)``.

    A fresh random seed is drawn first if *randomize_seed* is set. The seed is
    coerced to ``int`` because UI sliders may deliver numbers as floats.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    return seed, torch.Generator().manual_seed(seed)


def on_generate(
    subject: str,
    suffix: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Handle "Pull gacha": sample a new prompt, then render it.

    Returns:
        ``(image, prompt, seed)`` for the result, prompt box and seed slider.
    """
    seed, generator = _make_generator(seed, randomize_seed)

    ar = get_aspect_ratio(width, height)
    print("ar:", ar)
    prompt = generate_prompt(subject, ar)
    prompt = format_prompt(prompt, suffix)
    print(prompt)

    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    )

    return image, prompt, seed


def on_retry(
    prompt: str,
    negative_prompt: str,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Handle "Retry": re-render the previously generated prompt.

    Returns:
        ``(image, prompt, seed)`` for the result, prompt box and seed slider.
    """
    seed, generator = _make_generator(seed, randomize_seed)

    print(prompt)

    image = generate_image(
        prompt,
        negative_prompt,
        generator,
        width,
        height,
        guidance_scale,
        num_inference_steps,
    )

    return image, prompt, seed


css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            # IllustriousXL Random Gacha

            Image model: [IllustriousXL v0.1](https://huggingface.co/OnomaAIResearch/Illustrious-xl-early-release-v0)
            """
        )

        with gr.Row():
            subject_radio = gr.Dropdown(
                label="Subject",
                choices=["1girl", "2girls", "1boy", "no humans"],
                value="1girl",
            )
            run_button = gr.Button("Pull gacha", variant="primary", scale=0)

        result = gr.Image(label="Gacha result", show_label=False)

        with gr.Accordion("Generation details", open=False):
            with gr.Row():
                prompt_txt = gr.Textbox(label="Generated prompt", interactive=False)
                retry_button = gr.Button("🔄 Retry", scale=0)

        with gr.Accordion("Advanced Settings", open=False):
            prompt_suffix = gr.Text(
                label="Prompt suffix",
                visible=True,
                value=QUALITY_TAGS,
            )
            negative_prompt = gr.Text(
                label="Negative prompt",
                placeholder="Enter a negative prompt",
                visible=True,
                value=NEGATIVE_PROMPT,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=640,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=960,  # Replace with defaults that work for your model
                )
                height = gr.Slider(
                    label="Height",
                    minimum=640,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1344,  # Replace with defaults that work for your model
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=10.0,
                    step=0.5,
                    value=6.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=20,
                    maximum=40,
                    step=1,
                    value=28,
                )

    gr.on(
        triggers=[run_button.click],
        fn=on_generate,
        inputs=[
            subject_radio,
            prompt_suffix,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, prompt_txt, seed],
    )
    gr.on(
        triggers=[retry_button.click],
        fn=on_retry,
        inputs=[
            prompt_txt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, prompt_txt, seed],
    )

demo.queue().launch()