# diffusion / app.py
# Author: adamelliotfields
# Commit d179c4c (verified): "Terminal progress bar improvements"
# Size: 19.8 kB
import argparse
import json
import os
import random
import gradio as gr
from lib import (
Config,
async_call,
disable_progress_bars,
download_civit_file,
download_repo_files,
generate,
read_file,
)
# Client-side JS for the refresh (🔄) button: picks a random integer seed in
# the browser, stores it on the button's `--seed` CSS custom property, and
# returns it so Gradio writes it to the event's output component.
# The CSS `content` attribute expects a string, so the number is wrapped in
# quotes before being stored in the custom property.
refresh_seed_js = """
() => {
const n = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER);
const button = document.getElementById("refresh");
button.style.setProperty("--seed", `"${n}"`);
return n;
}
"""
# Client-side JS: copies a seed value onto the refresh button's `--seed` CSS
# custom property (wrapped in quotes because CSS `content` needs a string)
# and returns the seed unchanged.
seed_js = """
(seed) => {
const button = document.getElementById("refresh");
button.style.setProperty("--seed", `"${seed}"`);
return seed;
}
"""
# Client-side JS for the aspect-ratio dropdown.
# `ar` is either a "width,height" string (e.g. "512,512") or a falsy value
# (the "Custom" choice). For a falsy value the current width/height `w`, `h`
# are returned unchanged. Otherwise the string is split on "," and both
# parts are parsed to integers.
aspect_ratio_js = """
(ar, w, h) => {
if (!ar) return [w, h];
const [width, height] = ar.split(",");
return [parseInt(width), parseInt(height)];
}
"""
def create_image_dropdown(images, locked=False):
    """Build the dropdown used to pick a gallery image for an image input.

    Args:
        images: Current gallery value (may be ``None`` or empty).
        locked: When true, return a non-interactive dropdown whose only
            choice is the lock entry (value ``-2``).

    Returns:
        A ``gr.Dropdown``. When unlocked it offers "None" (value ``-1``)
        followed by one 1-based label per gallery image, whose value is the
        image's 0-based index. It starts on "None".
    """
    if locked:
        # A single, disabled lock entry signals "keep the current image".
        return gr.Dropdown(choices=[("🔒", -2)], interactive=False, value=-2)

    # Labels are shown 1-based; the values are 0-based gallery indices.
    numbered = [(str(pos), pos - 1) for pos, _ in enumerate(images or [], start=1)]
    return gr.Dropdown(
        choices=[("None", -1), *numbered],
        interactive=True,
        value=-1,
    )
async def gallery_fn(images, image, ip_image):
    """Rebuild both gallery-image pickers after the output gallery changes.

    A picker is locked when its image input already holds an image, so the
    user's image is kept. Otherwise the picker lists the current gallery.

    Returns:
        A pair of dropdowns: (initial-image picker, IP-Adapter image picker).
    """
    image_locked = image is not None
    ip_image_locked = ip_image is not None
    return (
        create_image_dropdown(images, locked=image_locked),
        create_image_dropdown(images, locked=ip_image_locked),
    )
async def image_prompt_fn(images):
    """Return an unlocked gallery picker listing ``images``, set to "None"."""
    dropdown = create_image_dropdown(images, locked=False)
    return dropdown
# handle selecting an image from the gallery
# -2 is the lock icon, -1 is None
async def image_select_fn(images, image, i):
    """Map a gallery-picker selection to an update for an image input.

    Args:
        images: Current output-gallery value; for a selected index ``i`` the
            image used is ``images[i][0]``.
        image: The image currently held by the target image input.
        i: Selected dropdown value. ``-2`` is the lock entry (keep ``image``),
            ``-1`` is "None" (clear the input), and any other value is a
            0-based gallery index.

    Returns:
        A ``gr.Image`` update for the target image input. Every branch now
        returns a ``gr.Image``; previously one fell back to a bare ``None``.
    """
    # Locked picker: keep whatever image is already in the input.
    if i == -2:
        return gr.Image(image)
    # "None" or an unset dropdown clears the input. An unset value used to
    # raise TypeError on `None > -1`.
    if i is None or i == -1:
        return gr.Image(None)
    # A stale or invalid index (e.g. the gallery changed underneath the
    # selection) keeps the current image instead of raising IndexError.
    if not images or not 0 <= i < len(images):
        return gr.Image(image)
    return gr.Image(images[i][0])
async def random_fn():
    """Return a Textbox update holding a randomly chosen example prompt.

    ``data/prompts.json`` is re-read on every call. It is expected to decode
    to a sequence of prompts, presumably a JSON array of strings (TODO
    confirm).
    """
    candidates = json.loads(read_file("data/prompts.json"))
    chosen = random.choice(candidates)
    return gr.Textbox(value=chosen)
async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
    """Validate the UI inputs and run ``generate`` through ``async_call``.

    ``args`` are the values of the components wired to the "generate" event,
    in order. The prompt comes first and the two session-state "disable"
    flags come last. The flags are not forwarded. When a flag is set, the
    matching image argument (position 2 = initial image, position 3 =
    IP-Adapter image) is replaced with ``None``, so the image stays in the UI
    but is not used for this run.

    Args:
        *args: Component values in the order of the event's ``inputs`` list.
        progress: Gradio progress tracker (``track_tqdm=True`` lets Gradio
            track tqdm progress bars); it is forwarded to ``generate``.

    Returns:
        The value returned by ``generate``, i.e. the images shown in the
        output gallery.

    Raises:
        gr.Error: If the prompt is missing or blank, or if generation raises
            ``RuntimeError``.
    """
    prompt = args[0] if args else None
    if prompt is None or not prompt.strip():
        raise gr.Error("You must enter a prompt")

    # The final two inputs are session flags, not generation arguments.
    # Lower-case names avoid shadowing the module-level gr.State objects.
    *gen_args, disable_image_prompt, disable_ip_image_prompt = args
    if disable_image_prompt:
        gen_args[2] = None  # initial image
    if disable_ip_image_prompt:
        gen_args[3] = None  # IP-Adapter image

    try:
        if Config.ZERO_GPU:
            progress((0, 100), desc="ZeroGPU init")
        images = await async_call(
            generate,
            *gen_args,
            Error=gr.Error,
            Info=gr.Info,
            progress=progress,
        )
    except RuntimeError as e:
        # Show a generic retry message, but keep the original error chained
        # so it remains visible in server-side tracebacks.
        raise gr.Error("Error: Please try again") from e
    return images
# Build the UI: four tabs (Text, Image, Control, Menu) followed by the event
# wiring. Inside each `with` context, components are laid out in creation
# order, so statements here should not be reordered casually.
with gr.Blocks(
    head=read_file("./partials/head.html"),
    css="./app.css",
    js="./app.js",
    theme=gr.themes.Default(
        # colors
        neutral_hue=gr.themes.colors.gray,
        primary_hue=gr.themes.colors.orange,
        secondary_hue=gr.themes.colors.blue,
        # sizing
        text_size=gr.themes.sizes.text_md,
        radius_size=gr.themes.sizes.radius_sm,
        spacing_size=gr.themes.sizes.spacing_md,
        # fonts
        font=[gr.themes.GoogleFont("Inter"), *Config.SANS_FONTS],
        font_mono=[gr.themes.GoogleFont("Ubuntu Mono"), *Config.MONO_FONTS],
    ).set(
        layout_gap="8px",
        block_shadow="0 0 #0000",
        block_shadow_dark="0 0 #0000",
        block_background_fill=gr.themes.colors.gray.c50,
        block_background_fill_dark=gr.themes.colors.gray.c900,
    ),
) as demo:
    # Per-session flags that mirror the "Disable ... Image" checkboxes.
    # generate_fn receives them as its last two inputs and drops the matching
    # image input for that run, without clearing the image from the UI.
    DISABLE_IMAGE_PROMPT = gr.State(False)
    DISABLE_IP_IMAGE_PROMPT = gr.State(False)

    gr.HTML(read_file("./partials/intro.html"))

    with gr.Tabs():
        # Main tab: output gallery, prompt box, and action buttons.
        with gr.TabItem("🏠 Text"):
            with gr.Column():
                with gr.Group():
                    output_images = gr.Gallery(
                        elem_classes=["gallery"],
                        show_share_button=False,
                        object_fit="cover",
                        interactive=False,
                        show_label=False,
                        label="Output",
                        format="png",
                        columns=2,
                    )
                    prompt = gr.Textbox(
                        placeholder="What do you want to see?",
                        autoscroll=False,
                        show_label=False,
                        label="Prompt",
                        max_lines=3,
                        lines=3,
                    )

                # Buttons
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary")
                    # Fills the prompt with a random example (random_fn).
                    random_btn = gr.Button(
                        elem_classes=["icon-button", "popover"],
                        variant="secondary",
                        elem_id="random",
                        min_width=0,
                        value="🎲",
                    )
                    # Sets a new random seed client-side (refresh_seed_js);
                    # elem_id "refresh" is looked up by the seed JS snippets.
                    refresh_btn = gr.Button(
                        elem_classes=["icon-button", "popover"],
                        variant="secondary",
                        elem_id="refresh",
                        min_width=0,
                        value="🔄",
                    )
                    # Clears the output gallery.
                    clear_btn = gr.ClearButton(
                        elem_classes=["icon-button", "popover"],
                        components=[output_images],
                        variant="secondary",
                        elem_id="clear",
                        min_width=0,
                        value="🗑️",
                    )

        # img2img tab: initial image and IP-Adapter image inputs, plus pickers
        # to reuse images from the output gallery.
        with gr.TabItem("🖼️ Image"):
            with gr.Group():
                with gr.Row():
                    image_prompt = gr.Image(
                        show_share_button=False,
                        label="Initial Image",
                        min_width=320,
                        format="png",
                        type="pil",
                    )
                    ip_image_prompt = gr.Image(
                        show_share_button=False,
                        label="IP-Adapter Image",
                        min_width=320,
                        format="png",
                        type="pil",
                    )

                # Dropdown values: -2 = locked, -1 = None, >= 0 = gallery index.
                with gr.Row():
                    image_select = gr.Dropdown(
                        info="Use an initial image from the gallery",
                        choices=[("None", -1)],
                        label="Gallery Image",
                        interactive=True,
                        filterable=False,
                        value=-1,
                    )
                    ip_image_select = gr.Dropdown(
                        info="Use an IP-Adapter image from the gallery",
                        label="Gallery Image (IP-Adapter)",
                        choices=[("None", -1)],
                        interactive=True,
                        filterable=False,
                        value=-1,
                    )

                with gr.Row():
                    denoising_strength = gr.Slider(
                        value=Config.DENOISING_STRENGTH,
                        label="Denoising Strength",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                    )

                with gr.Row():
                    disable_image = gr.Checkbox(
                        elem_classes=["checkbox"],
                        label="Disable Initial Image",
                        value=False,
                    )
                    disable_ip_image = gr.Checkbox(
                        elem_classes=["checkbox"],
                        label="Disable IP-Adapter Image",
                        value=False,
                    )
                    ip_face = gr.Checkbox(
                        elem_classes=["checkbox"],
                        label="Use IP-Adapter Face",
                        value=False,
                    )

        # ControlNet tab (placeholder text only for now)
        with gr.TabItem("🎮 Control"):
            gr.Markdown(
                "[ControlNet](https://github.com/lllyasviel/ControlNet) with [preprocessors](https://github.com/huggingface/controlnet_aux) coming soon!"
            )

        # Settings tab: all generation parameters.
        with gr.TabItem("⚙️ Menu"):
            with gr.Group():
                negative_prompt = gr.Textbox(
                    value="nsfw+",
                    label="Negative Prompt",
                    lines=2,
                )

                with gr.Row():
                    model = gr.Dropdown(
                        choices=Config.MODELS,
                        filterable=False,
                        value=Config.MODEL,
                        label="Model",
                        min_width=240,
                    )
                    # NOTE(review): passes a dict_keys view rather than a list;
                    # consider list(...) to match Gradio's documented `choices` type.
                    scheduler = gr.Dropdown(
                        choices=Config.SCHEDULERS.keys(),
                        value=Config.SCHEDULER,
                        elem_id="scheduler",
                        label="Scheduler",
                        filterable=False,
                    )

                with gr.Row():
                    # Style ids starting with "_" are excluded from the dropdown.
                    styles = json.loads(read_file("data/styles.json"))
                    style_ids = list(styles.keys())
                    style_ids = [sid for sid in style_ids if not sid.startswith("_")]
                    style = gr.Dropdown(
                        value=Config.STYLE,
                        label="Style",
                        min_width=240,
                        choices=[("None", "none")]
                        + [(styles[sid]["name"], sid) for sid in style_ids],
                    )
                    embeddings = gr.Dropdown(
                        elem_id="embeddings",
                        label="Embeddings",
                        choices=[(f"<{e}>", e) for e in Config.EMBEDDINGS],
                        multiselect=True,
                        value=[Config.EMBEDDING],
                        min_width=240,
                    )

                # Two LoRA slots, each a dropdown plus an unlabeled weight slider.
                with gr.Row():
                    with gr.Group(elem_classes=["gap-0"]):
                        lora_1 = gr.Dropdown(
                            min_width=240,
                            label="LoRA #1",
                            value="none",
                            choices=[("None", "none")]
                            + [
                                (lora["name"], lora_id)
                                for lora_id, lora in Config.CIVIT_LORAS.items()
                            ],
                        )
                        lora_1_weight = gr.Slider(
                            value=0.0,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            show_label=False,
                        )
                    with gr.Group(elem_classes=["gap-0"]):
                        lora_2 = gr.Dropdown(
                            min_width=240,
                            label="LoRA #2",
                            value="none",
                            choices=[("None", "none")]
                            + [
                                (lora["name"], lora_id)
                                for lora_id, lora in Config.CIVIT_LORAS.items()
                            ],
                        )
                        lora_2_weight = gr.Slider(
                            value=0.0,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            show_label=False,
                        )

                with gr.Row():
                    guidance_scale = gr.Slider(
                        value=Config.GUIDANCE_SCALE,
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=15.0,
                        step=0.1,
                    )
                    inference_steps = gr.Slider(
                        value=Config.INFERENCE_STEPS,
                        label="Inference Steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                    )
                    deepcache_interval = gr.Slider(
                        value=Config.DEEPCACHE_INTERVAL,
                        label="DeepCache",
                        minimum=1,
                        maximum=4,
                        step=1,
                    )

                with gr.Row():
                    width = gr.Slider(
                        value=Config.WIDTH,
                        label="Width",
                        minimum=256,
                        maximum=768,
                        step=32,
                    )
                    height = gr.Slider(
                        value=Config.HEIGHT,
                        label="Height",
                        minimum=256,
                        maximum=768,
                        step=32,
                    )
                    # Values are "width,height" strings parsed by aspect_ratio_js;
                    # "Custom" (None) leaves the sliders untouched.
                    aspect_ratio = gr.Dropdown(
                        value=f"{Config.WIDTH},{Config.HEIGHT}",
                        label="Aspect Ratio",
                        filterable=False,
                        choices=[
                            ("Custom", None),
                            ("4:7 (384x672)", "384,672"),
                            ("7:9 (448x576)", "448,576"),
                            ("1:1 (512x512)", "512,512"),
                            ("9:7 (576x448)", "576,448"),
                            ("7:4 (672x384)", "672,384"),
                        ],
                    )

                with gr.Row():
                    file_format = gr.Dropdown(
                        choices=["png", "jpeg", "webp"],
                        label="File Format",
                        filterable=False,
                        value="png",
                    )
                    num_images = gr.Dropdown(
                        choices=list(range(1, 5)),
                        value=Config.NUM_IMAGES,
                        filterable=False,
                        label="Images",
                    )
                    scale = gr.Dropdown(
                        choices=[(f"{s}x", s) for s in Config.SCALES],
                        filterable=False,
                        value=Config.SCALE,
                        label="Scale",
                    )
                    # NOTE(review): maximum is 2**64 - 1, but the refresh button's
                    # JS draws seeds below Number.MAX_SAFE_INTEGER (2**53 - 1);
                    # larger values are not exactly representable as JS numbers.
                    seed = gr.Number(
                        value=Config.SEED,
                        label="Seed",
                        minimum=-1,
                        maximum=(2**64) - 1,
                    )

                with gr.Row():
                    use_karras = gr.Checkbox(
                        elem_classes=["checkbox"],
                        label="Karras σ",
                        value=True,
                    )
                    use_taesd = gr.Checkbox(
                        elem_classes=["checkbox"],
                        label="Tiny VAE",
                        value=False,
                    )
                    use_freeu = gr.Checkbox(
                        elem_classes=["checkbox"],
                        label="FreeU",
                        value=False,
                    )
                    use_clip_skip = gr.Checkbox(
                        elem_classes=["checkbox"],
                        label="Clip skip",
                        value=False,
                    )

    random_btn.click(random_fn, inputs=[], outputs=[prompt], show_api=False)

    # Seed handling runs entirely in the browser (fn=None, js only).
    refresh_btn.click(None, inputs=[], outputs=[seed], js=refresh_seed_js)

    seed.change(None, inputs=[seed], outputs=[], js=seed_js)

    # Keep the gallery and both image inputs on the selected file format.
    file_format.change(
        lambda f: (gr.Gallery(format=f), gr.Image(format=f), gr.Image(format=f)),
        inputs=[file_format],
        outputs=[output_images, image_prompt, ip_image_prompt],
        show_api=False,
    )

    # input events are only user input; change events are both user and programmatic
    aspect_ratio.input(
        None,
        inputs=[aspect_ratio, width, height],
        outputs=[width, height],
        js=aspect_ratio_js,
    )

    # lock the input images so you don't lose them when the gallery updates
    output_images.change(
        gallery_fn,
        inputs=[output_images, image_prompt, ip_image_prompt],
        outputs=[image_select, ip_image_select],
        show_api=False,
    )

    # show the selected image in the image input
    image_select.change(
        image_select_fn,
        inputs=[output_images, image_prompt, image_select],
        outputs=[image_prompt],
        show_api=False,
    )
    ip_image_select.change(
        image_select_fn,
        inputs=[output_images, ip_image_prompt, ip_image_select],
        outputs=[ip_image_prompt],
        show_api=False,
    )

    # reset the dropdown on clear
    image_prompt.clear(
        image_prompt_fn,
        inputs=[output_images],
        outputs=[image_select],
        show_api=False,
    )
    ip_image_prompt.clear(
        image_prompt_fn,
        inputs=[output_images],
        outputs=[ip_image_select],
        show_api=False,
    )

    # show "Custom" aspect ratio when manually changing width or height
    gr.on(
        triggers=[width.input, height.input],
        fn=None,
        inputs=[],
        outputs=[aspect_ratio],
        js="() => { return null; }",
    )

    # toggle image prompts by updating session state
    # NOTE(review): unlike the other helper events, this one does not pass
    # show_api=False, so it is not hidden from the API.
    gr.on(
        triggers=[disable_image.input, disable_ip_image.input],
        fn=lambda disable_image, disable_ip_image: (disable_image, disable_ip_image),
        inputs=[disable_image, disable_ip_image],
        outputs=[DISABLE_IMAGE_PROMPT, DISABLE_IP_IMAGE_PROMPT],
    )

    # generate images
    # The order of `inputs` is the positional order generate_fn receives:
    # image_prompt/ip_image_prompt must stay at positions 2/3 and the two
    # session flags must stay last.
    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=generate_fn,
        api_name="generate",
        outputs=[output_images],
        inputs=[
            prompt,
            negative_prompt,
            image_prompt,
            ip_image_prompt,
            ip_face,
            lora_1,
            lora_1_weight,
            lora_2,
            lora_2_weight,
            embeddings,
            style,
            seed,
            model,
            scheduler,
            width,
            height,
            guidance_scale,
            inference_steps,
            denoising_strength,
            deepcache_interval,
            scale,
            num_images,
            use_karras,
            use_taesd,
            use_freeu,
            use_clip_skip,
            DISABLE_IMAGE_PROMPT,
            DISABLE_IP_IMAGE_PROMPT,
        ],
    )
if __name__ == "__main__":
    # Script entry point: parse the bind address, pre-download model files
    # and LoRAs, then launch the Gradio server.
    parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    parser.add_argument("-s", "--server", type=str, metavar="STR", default="0.0.0.0")
    parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
    args = parser.parse_args()

    # Presumably silences library (e.g. download) progress bars — see lib.
    disable_progress_bars()

    # Download the Hugging Face files listed in Config.HF_MODELS
    # (repo id -> allow patterns) before launching.
    for repo_id, allow_patterns in Config.HF_MODELS.items():
        download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)

    # download civit loras into a "loras" directory next to this file; the
    # path is loop-invariant, so compute it once.
    loras_dir = os.path.join(os.path.dirname(__file__), "loras")
    for lora_id, lora in Config.CIVIT_LORAS.items():
        download_civit_file(
            lora_id,
            lora["model_version_id"],
            file_path=loras_dir,
            token=Config.CIVIT_TOKEN,
        )

    # https://www.gradio.app/docs/gradio/interface#interface-queue
    # default_concurrency_limit=1: each event listener processes one request
    # at a time.
    demo.queue(default_concurrency_limit=1).launch(
        server_name=args.server,
        server_port=args.port,
    )