John6666's picture
Upload 12 files
35b1cf8 verified
raw
history blame contribute delete
No virus
46.6 kB
import spaces
import gradio as gr
import json
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
from diffusers.utils import load_image
from diffusers import FluxControlNetPipeline, FluxControlNetModel, FluxMultiControlNetModel, FluxControlNetImg2ImgPipeline
from huggingface_hub import HfFileSystem, ModelCard
import random
import time
from env import models, num_loras, num_cns
from mod import (clear_cache, get_repo_safetensors, is_repo_name, is_repo_exists, get_model_trigger,
description_ui, compose_lora_json, is_valid_lora, fuse_loras, save_image, preprocess_i2i_image,
get_trigger_word, enhance_prompt, deselect_lora, set_control_union_image,
get_control_union_mode, set_control_union_mode, get_control_params, translate_to_en)
from flux import (search_civitai_lora, select_civitai_lora, search_civitai_lora_json,
download_my_lora, get_all_lora_tupled_list, apply_lora_prompt,
update_loras, get_t2i_model_info)
from tagger.tagger import predict_tags_wd, compose_prompt_to_copy
from tagger.fl2flux import predict_tags_fl2_flux
# Load the curated LoRA gallery entries from JSON.
# JSON text is UTF-8 by specification; pass the encoding explicitly so non-ASCII
# titles/trigger words decode the same way regardless of the platform's default locale.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)
# Initialize the base model
# The first entry of `models` (imported from env) is the default base model.
base_model = models[0]
# ControlNet Union repo; change_base_model() loads it when ControlNet is enabled.
controlnet_model_union_repo = 'InstantX/FLUX.1-dev-Controlnet-Union'
#controlnet_model_union_repo = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
# Weight dtype used for every model loaded in this file.
dtype = torch.bfloat16
#dtype = torch.float8_e4m3fn
#device = "cuda" if torch.cuda.is_available() else "cpu"
# Two VAEs are kept: taef1 (AutoencoderTiny) and the base model's full AutoencoderKL.
# generate_image() attaches taef1 to `pipe` on its non-ControlNet path (and passes
# good_vae to the live-preview helper); the image-to-image and ControlNet paths use good_vae.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1)
# The image-to-image pipeline reuses pipe's transformer, text encoders and tokenizers
# instead of loading a second copy of them.
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(base_model, vae=good_vae, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype)
# ControlNet modules are created by change_base_model() when ControlNet is requested.
controlnet_union = None
controlnet = None
# What is currently loaded; change_base_model() compares against these to skip reloads.
last_model = models[0]
last_cn_on = False
#controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
#controlnet = FluxMultiControlNetModel([controlnet_union])
#controlnet.config = controlnet_union.config
# Largest seed value offered by the UI and by run_lora's seed randomization (2**32 - 1).
MAX_SEED = 2**32-1
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union
# https://huggingface.co/spaces/jiuface/FLUX.1-dev-Controlnet-Union
# https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux
#@spaces.GPU()
def change_base_model(repo_id: str, cn_on: bool, disable_model_cache: bool, progress=gr.Progress(track_tqdm=True)):
    """Replace the global FLUX pipelines with ones built from ``repo_id``.

    When ``cn_on`` is true, ControlNet pipelines are built around the ControlNet Union
    model (``controlnet_model_union_repo``); otherwise plain text-to-image and
    image-to-image pipelines are built and any previously loaded ControlNet modules are
    released. Currently loaded modules are moved to CPU and the cache is cleared first.

    The reload is skipped when the same (model, ControlNet) combination is already
    loaded and ``disable_model_cache`` is false, or when ``repo_id`` is not a valid,
    existing repository.

    Args:
        repo_id: Hugging Face repo id of the base model.
        cn_on: whether to build ControlNet pipelines.
        disable_model_cache: force a reload even if the same model is already loaded.
        progress: Gradio progress tracker.

    Returns:
        ``gr.update(visible=True)`` for the result component.

    Raises:
        gr.Error: if any part of the load fails.
    """
    global pipe, pipe_i2i, taef1, good_vae, controlnet_union, controlnet, last_model, last_cn_on, dtype
    try:
        # Explicit grouping of the original `A and B or C or D` skip condition.
        same_as_loaded = repo_id == last_model and bool(cn_on) == bool(last_cn_on)
        if (same_as_loaded and not disable_model_cache) or not is_repo_name(repo_id) or not is_repo_exists(repo_id):
            return gr.update(visible=True)
        # Move everything currently loaded off the GPU and free cached memory before loading.
        pipe.to("cpu")
        pipe_i2i.to("cpu")
        good_vae.to("cpu")
        taef1.to("cpu")
        if controlnet is not None:
            controlnet.to("cpu")
        if controlnet_union is not None:
            controlnet_union.to("cpu")
        clear_cache()
        if cn_on:
            progress(0, desc=f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
            print(f"Loading model: {repo_id} / Loading ControlNet: {controlnet_model_union_repo}")
            controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union_repo, torch_dtype=dtype)
            controlnet = FluxMultiControlNetModel([controlnet_union])
            controlnet.config = controlnet_union.config
            pipe = FluxControlNetPipeline.from_pretrained(repo_id, controlnet=controlnet, torch_dtype=dtype)
            # vae=None: generate_image_to_image() attaches good_vae before each call.
            pipe_i2i = FluxControlNetImg2ImgPipeline.from_pretrained(
                repo_id, controlnet=controlnet, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
                tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype)
            last_model = repo_id
            last_cn_on = cn_on
            progress(1, desc=f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
            print(f"Model loaded: {repo_id} / ControlNet Loaded: {controlnet_model_union_repo}")
        else:
            progress(0, desc=f"Loading model: {repo_id}")
            print(f"Loading model: {repo_id}")
            pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=dtype)
            # vae=None: generate_image_to_image() attaches good_vae before each call.
            pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
                repo_id, vae=None, transformer=pipe.transformer, text_encoder=pipe.text_encoder,
                tokenizer=pipe.tokenizer, text_encoder_2=pipe.text_encoder_2, tokenizer_2=pipe.tokenizer_2, torch_dtype=dtype)
            # The new pipelines do not use ControlNet; drop the old modules so they are not
            # kept resident in RAM. They are rebuilt when ControlNet is enabled again.
            controlnet_union = None
            controlnet = None
            last_model = repo_id
            last_cn_on = cn_on
            progress(1, desc=f"Model loaded: {repo_id}")
            print(f"Model loaded: {repo_id}")
    except Exception as e:
        print(f"Model load Error: {e}")
        raise gr.Error(f"Model load Error: {e}") from e
    return gr.update(visible=True)
change_base_model.zerogpu = True
class calculateDuration:
    """Context manager that prints how long the enclosed block took.

    After the ``with`` block exits, ``start_time`` and ``end_time`` hold
    ``time.perf_counter()`` readings (not epoch timestamps), and ``elapsed_time`` holds
    their difference in seconds. Exceptions raised inside the block propagate unchanged.

    Args:
        activity_name: optional label included in the printed message.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() follows the system
        # clock and can jump (even backwards) when it is adjusted, corrupting durations.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
def update_selection(evt: gr.SelectData, width, height):
    """Handle a click on the LoRA gallery.

    Looks up the clicked entry in ``loras`` and returns, in order: a prompt placeholder
    update, the "Selected" markdown label, the clicked index, and the width/height to
    use. Entries with an ``aspect`` key override the size (portrait 768x1024, landscape
    1024x768, anything else 1024x1024); entries without it keep the incoming size.
    """
    chosen = loras[evt.index]
    placeholder_text = f"Type a prompt for {chosen['title']}"
    repo_id = chosen["repo"]
    selected_label = f"### Selected: [{repo_id}](https://huggingface.co/{repo_id}) ✨"
    if "aspect" in chosen:
        aspect = chosen["aspect"]
        if aspect == "portrait":
            width, height = 768, 1024
        elif aspect == "landscape":
            width, height = 1024, 768
        else:
            width, height = 1024, 1024
    return (
        gr.update(placeholder=placeholder_text),
        selected_label,
        evt.index,
        width,
        height,
    )
@spaces.GPU(duration=70)
@torch.inference_mode()
def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress=gr.Progress(track_tqdm=True)):
    """Run text-to-image inference on the global ``pipe`` and yield the produced images.

    Two paths:
      * Without ControlNet (``cn_on`` is falsy or ``get_control_params()`` returned no
        modes): ``pipe.vae`` is set to ``taef1`` and the live-preview helper
        ``flux_pipe_call_that_returns_an_iterable_of_images`` is bound to ``pipe``; every
        image that helper yields is re-yielded. ``good_vae`` is passed to the helper too —
        presumably for the final full-quality decode (confirm in live_preview_helpers).
      * With ControlNet: ``pipe.vae`` is set to ``good_vae``, the ControlNet modules are
        moved to CUDA, model CPU offload is enabled, and the pipeline is called once with
        the control images/modes/scales; each output image is yielded.

    Args:
        prompt_mash: final prompt text (user prompt plus trigger words).
        steps: number of inference steps.
        seed: seed value, converted with ``int(float(seed))``.
        cfg_scale: passed as ``guidance_scale``.
        width: output width in pixels.
        height: output height in pixels.
        lora_scale: passed as ``joint_attention_kwargs["scale"]``.
        cn_on: whether the ControlNet path may be used.
        progress: Gradio progress tracker.

    Yields:
        Images produced by the selected path.

    Raises:
        gr.Error: wraps any exception raised during inference.
    """
    global pipe, taef1, good_vae, controlnet, controlnet_union
    try:
        good_vae.to("cuda")
        taef1.to("cuda")
        generator = torch.Generator(device="cuda").manual_seed(int(float(seed)))
        with calculateDuration("Generating image"):
            # Generate image
            # NOTE(review): presumably one entry per active ControlNet slot — confirm in mod.get_control_params.
            modes, images, scales = get_control_params()
            if not cn_on or len(modes) == 0:
                pipe.to("cuda")
                pipe.vae = taef1
                # Bind the module-level helper function to this pipeline instance as a method.
                pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
                progress(0, desc="Start Inference.")
                for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
                    prompt=prompt_mash,
                    num_inference_steps=steps,
                    guidance_scale=cfg_scale,
                    width=width,
                    height=height,
                    generator=generator,
                    joint_attention_kwargs={"scale": lora_scale},
                    output_type="pil",
                    good_vae=good_vae,
                ):
                    yield img
            else:
                pipe.to("cuda")
                pipe.vae = good_vae
                if controlnet_union is not None: controlnet_union.to("cuda")
                if controlnet is not None: controlnet.to("cuda")
                # NOTE(review): model CPU offload is enabled right after the explicit .to("cuda")
                # calls above; offloading manages device placement itself, so those moves look
                # redundant on this path — confirm before changing.
                pipe.enable_model_cpu_offload()
                progress(0, desc="Start Inference with ControlNet.")
                for img in pipe(
                    prompt=prompt_mash,
                    control_image=images,
                    control_mode=modes,
                    num_inference_steps=steps,
                    guidance_scale=cfg_scale,
                    width=width,
                    height=height,
                    controlnet_conditioning_scale=scales,
                    generator=generator,
                    joint_attention_kwargs={"scale": lora_scale},
                ).images:
                    yield img
    except Exception as e:
        print(e)
        raise gr.Error(f"Inference Error: {e}") from e
@spaces.GPU(duration=70)
@torch.inference_mode()
def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, lora_scale, seed, cn_on, progress=gr.Progress(track_tqdm=True)):
    """Run image-to-image inference on the global ``pipe_i2i`` and return one image.

    ``pipe_i2i.vae`` is always set to ``good_vae``. When ``cn_on`` is true and
    ``get_control_params()`` returned at least one mode, the ControlNet modules are moved
    to CUDA, model CPU offload is enabled, and the control images/modes/scales are passed
    to the pipeline; otherwise a plain image-to-image call is made.

    Args:
        prompt_mash: final prompt text (user prompt plus trigger words).
        image_input_path: path/URL of the source image, loaded with ``load_image``.
        image_strength: denoise strength (``strength``).
        steps: number of inference steps.
        cfg_scale: passed as ``guidance_scale``.
        width: output width in pixels.
        height: output height in pixels.
        lora_scale: passed as ``joint_attention_kwargs["scale"]``.
        seed: seed value, converted with ``int(float(seed))``.
        cn_on: whether the ControlNet path may be used.
        progress: Gradio progress tracker.

    Returns:
        The first image of the pipeline output (PIL).

    Raises:
        gr.Error: wraps any exception raised during inference.
    """
    global pipe_i2i, good_vae, controlnet, controlnet_union
    try:
        good_vae.to("cuda")
        generator = torch.Generator(device="cuda").manual_seed(int(float(seed)))
        # Load the source image once; it used to be re-loaded from disk in each branch.
        image_input = load_image(image_input_path)
        with calculateDuration("Generating image"):
            modes, images, scales = get_control_params()
            use_controlnet = bool(cn_on) and len(modes) > 0
            pipe_i2i.to("cuda")
            pipe_i2i.vae = good_vae
            # Arguments shared by both paths; the ControlNet path adds its own below.
            call_kwargs = dict(
                prompt=prompt_mash,
                image=image_input,
                strength=image_strength,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                joint_attention_kwargs={"scale": lora_scale},
                output_type="pil",
            )
            if use_controlnet:
                if controlnet_union is not None:
                    controlnet_union.to("cuda")
                if controlnet is not None:
                    controlnet.to("cuda")
                pipe_i2i.enable_model_cpu_offload()
                call_kwargs.update(
                    control_image=images,
                    control_mode=modes,
                    controlnet_conditioning_scale=scales,
                )
                progress(0, desc="Start I2I Inference with ControlNet.")
            else:
                progress(0, desc="Start I2I Inference.")
            final_image = pipe_i2i(**call_kwargs).images[0]
        return final_image
    except Exception as e:
        print(e)
        raise gr.Error(f"I2I Inference Error: {e}") from e
def _unload_loras(*pipelines):
    """Best-effort removal of LoRA weights from every given pipeline.

    ``unfuse_lora`` and ``unload_lora_weights`` are attempted separately for each
    pipeline. A failure in one call therefore no longer skips the remaining calls and
    leaves stale LoRA weights loaded elsewhere. Errors are printed, not raised.
    """
    for pipeline in pipelines:
        for method_name in ("unfuse_lora", "unload_lora_weights"):
            try:
                getattr(pipeline, method_name)()
            except Exception as e:
                print(e)


def _add_trigger_word(prompt_text, trigger_word, position="prepend"):
    """Return ``prompt_text`` with ``trigger_word`` prepended (``position == "prepend"``) or appended.

    Any other ``position`` value appends. An empty/missing trigger word leaves the
    prompt unchanged.
    """
    if not trigger_word:
        return prompt_text
    if position == "prepend":
        return f"{trigger_word} {prompt_text}"
    return f"{prompt_text} {trigger_word}"


def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
             lora_scale, lora_json, cn_on, translate_on, progress=gr.Progress(track_tqdm=True)):
    """Prepare LoRAs and prompt, then run text-to-image or image-to-image generation.

    Steps: unload any previous LoRA weights; optionally translate the prompt; append the
    base model's trigger; fuse the external LoRAs described by ``lora_json``; load the
    gallery LoRA at ``selected_index`` (into ``pipe_i2i`` when ``image_input`` is given,
    otherwise into ``pipe``); optionally randomize the seed; then generate.

    Yields:
        ``(image, seed, progress_bar_update)`` tuples. The text-to-image path yields each
        preview image with an HTML progress bar and finally the saved image. The
        image-to-image path yields the saved image once.
    """
    global pipe
    if selected_index is None and not is_valid_lora(lora_json):
        gr.Info("LoRA isn't selected.")
        # raise gr.Error("You must select a LoRA before proceeding.")
    progress(0, desc="Preparing Inference.")
    with calculateDuration("Unloading LoRA"):
        _unload_loras(pipe, pipe_i2i)
    clear_cache()
    if translate_on:
        prompt = translate_to_en(prompt)
    prompt_mash = prompt + get_model_trigger(last_model)
    if is_valid_lora(lora_json):
        # Load External LoRA weights
        with calculateDuration("Loading External LoRA weights"):
            fuse_loras(pipe, lora_json)
            fuse_loras(pipe_i2i, lora_json)
        trigger_word = get_trigger_word(lora_json)
        # Extend prompt_mash (not the raw prompt) so the base-model trigger added above is
        # kept, as it is on the gallery-LoRA path below.
        prompt_mash = f"{prompt_mash} {trigger_word}"
    if selected_index is not None:
        selected_lora = loras[selected_index]
        lora_path = selected_lora["repo"]
        prompt_mash = _add_trigger_word(
            prompt_mash,
            selected_lora.get("trigger_word"),
            selected_lora.get("trigger_position", "prepend"),
        )
        # The gallery LoRA goes only into the pipeline that will actually run.
        target_pipe = pipe_i2i if image_input is not None else pipe
        with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
            if "weights" in selected_lora:
                target_pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
            else:
                target_pipe.load_lora_weights(lora_path)
    # Set random seed for reproducibility
    with calculateDuration("Randomizing seed"):
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
    progress(0, desc="Running Inference.")
    if image_input is not None:
        final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, lora_scale, seed, cn_on, progress)
        yield save_image(final_image, None, last_model, prompt_mash, height, width, steps, cfg_scale, seed), seed, gr.update(visible=False)
    else:
        image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, cn_on, progress)
        # Consume the generator to get the final image
        final_image = None
        step_counter = 0
        # Defined up front so the final yield works even if the generator yields nothing.
        progress_bar = ""
        for image in image_generator:
            step_counter += 1
            final_image = image
            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
            yield image, seed, gr.update(value=progress_bar, visible=True)
        yield save_image(final_image, None, last_model, prompt_mash, height, width, steps, cfg_scale, seed), seed, gr.update(value=progress_bar, visible=False)
def get_huggingface_safetensors(link):
    """Resolve a ``user/repo`` id to the metadata of a FLUX LoRA hosted on the Hub.

    The repo's model card must declare a FLUX.1-dev or FLUX.1-schnell base model (either
    as a single id or inside a list). The LoRA file is the last ``.safetensors`` file
    listed in the repo. The preview image comes from the card's first widget output,
    falling back to the first image file in the repo.

    Returns:
        ``(title, repo_id, safetensors_filename, trigger_word, image_url)`` where
        ``image_url`` may be None.

    Raises:
        ValueError: ``link`` is not of the form ``user/repo``.
        Exception: the base model is not FLUX, the repo cannot be listed, or it contains
            no ``.safetensors`` file.
    """
    split_link = link.split("/")
    if len(split_link) != 2:
        # Previously this fell through and returned None, which only failed later when
        # the caller unpacked the result; fail here with a clear message instead.
        raise ValueError(f"Expected a Hugging Face repo id of the form 'user/repo', got: {link!r}")
    model_card = ModelCard.load(link)
    card_base_model = model_card.data.get("base_model")
    print(card_base_model)
    flux_base_models = ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell")
    # A model card's base_model may be a single repo id or a list of ids.
    candidates = card_base_model if isinstance(card_base_model, list) else [card_base_model]
    if not any(candidate in flux_base_models for candidate in candidates):
        raise Exception("Not a FLUX LoRA!")
    widgets = model_card.data.get("widget")
    if not isinstance(widgets, list) or not widgets:
        widgets = [{}]  # tolerate a missing or empty widget list
    first_widget = widgets[0] if isinstance(widgets[0], dict) else {}
    image_path = (first_widget.get("output") or {}).get("url", None)
    trigger_word = model_card.data.get("instance_prompt", "")
    image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
    not_found_message = "You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA"
    fs = HfFileSystem()
    try:
        list_of_files = fs.ls(link, detail=False)
    except Exception as e:
        print(e)
        gr.Warning(not_found_message)
        raise Exception(not_found_message) from e
    safetensors_name = None
    for file in list_of_files:
        file_name = file.split("/")[-1]
        if file.endswith(".safetensors"):
            # As before, the last .safetensors file listed wins.
            safetensors_name = file_name
        if not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
            image_url = f"https://huggingface.co/{link}/resolve/main/{file_name}"
    if safetensors_name is None:
        # Previously an UnboundLocalError at the return statement.
        gr.Warning(not_found_message)
        raise Exception(not_found_message)
    return split_link[1], link, safetensors_name, trigger_word, image_url
def check_custom_model(link):
    """Normalize a custom-LoRA reference and resolve it on the Hub.

    Accepts a bare ``user/repo`` id or an ``http(s)://(www.)huggingface.co/user/repo``
    URL. Surrounding whitespace and leading/trailing slashes are ignored.

    Returns:
        The tuple returned by ``get_huggingface_safetensors``.

    Raises:
        ValueError: ``link`` is a URL whose host is not huggingface.co. Previously this
            case silently returned None.
    """
    from urllib.parse import urlparse

    link = link.strip()
    if link.startswith(("https://", "http://")):
        parsed = urlparse(link)
        if parsed.netloc.lower() not in ("huggingface.co", "www.huggingface.co"):
            raise ValueError(f"Not a Hugging Face link: {link}")
        link = parsed.path
    return get_huggingface_safetensors(link.strip("/"))
def add_custom_lora(custom_lora):
    """Resolve a user-supplied LoRA link/repo id and register it in the global ``loras`` list.

    Returns a 6-tuple for the outputs (custom_lora_info, custom_lora_button, gallery,
    selected_info, selected_index, prompt):
      * success: the info card HTML, a visible "Remove" button, a gallery reset, the
        "Custom: <file>" label, the LoRA's index in ``loras``, and its trigger word;
      * invalid link: an error message card (a warning is also shown) and empty values;
      * empty input: hidden card and button.
    """
    import html

    global loras
    if not custom_lora:
        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
    try:
        title, repo, path, trigger_word, image = check_custom_model(custom_lora)
        print(f"Loaded custom LoRA: {repo}")
        # Title, trigger word and image URL come from a third-party model card; escape them
        # before embedding them in HTML.
        safe_title = html.escape(str(title))
        safe_image = html.escape(image or "", quote=True)
        if trigger_word:
            trigger_note = "Using: <code><b>" + html.escape(str(trigger_word)) + "</b></code> as the trigger word"
        else:
            trigger_note = "No trigger word found. If there's a trigger word, include it in your prompt"
        card = f'''
        <div class="custom_lora_card">
            <span>Loaded custom LoRA:</span>
            <div class="card_internal">
                <img src="{safe_image}" />
                <div>
                    <h3>{safe_title}</h3>
                    <small>{trigger_note}<br></small>
                </div>
            </div>
        </div>
        '''
        existing_item_index = next((index for index, item in enumerate(loras) if item['repo'] == repo), None)
        # Compare against None: index 0 is a valid match and must not cause a duplicate entry.
        if existing_item_index is None:
            new_item = {
                "image": image,
                "title": title,
                "repo": repo,
                "weights": path,
                "trigger_word": trigger_word
            }
            print(new_item)
            existing_item_index = len(loras)
            loras.append(new_item)
        return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
    except Exception as e:
        print(e)
        gr.Warning("Invalid LoRA: either you entered an invalid link, or a non-FLUX LoRA")
        return gr.update(visible=True, value="Invalid LoRA: either you entered an invalid link, a non-FLUX LoRA"), gr.update(visible=True), gr.update(), "", None, ""
def remove_custom_lora():
    """Reset the custom-LoRA widgets.

    Hides the info card and the remove button, leaves the gallery untouched, clears the
    selection label, resets the selected index to None, and empties the last output.
    """
    hidden_card = gr.update(visible=False)
    hidden_button = gr.update(visible=False)
    gallery_unchanged = gr.update()
    return hidden_card, hidden_button, gallery_unchanged, "", None, ""
# Flag attribute set on run_lora, the same way change_base_model.zerogpu is set after its definition.
# NOTE(review): presumably read by the `spaces` ZeroGPU integration — confirm.
run_lora.zerogpu = True
# Custom CSS passed to gr.Blocks(css=...); selectors match the elem_id / elem_classes
# used by the components below. `.info` previously had a stray `;` before `!important`,
# which made that declaration invalid.
css = '''
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
.card_internal{display: flex;height: 100px;margin-top: .5em}
.card_internal img{margin-right: 1em}
.styler{--form-gap-width: 0px !important}
#progress{height:30px}
#progress .generating{display:none}
.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
.info {text-align:center !important}
'''
# NOTE(review): the indentation of this UI section is missing in this copy of the file.
# The nesting of the `with` blocks below defines the page layout and must be restored
# from the original source before this module can run.
# Main Gradio app. Tab 1 ("FLUX LoRA the Explorer") builds the generator UI; the
# "FLUX Prompt Generator" tab is defined further below.
with gr.Blocks(theme='Nymbo/Nymbo_Theme', fill_width=True, css=css, delete_cache=(60, 3600)) as app:
with gr.Tab("FLUX LoRA the Explorer"):
title = gr.HTML(
"""<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA">FLUX LoRA the Explorer Mod</h1>""",
elem_id="title",
)
# Index into `loras` of the gallery LoRA; None means no gallery LoRA is selected (see run_lora).
selected_index = gr.State(None)
with gr.Row():
with gr.Column(scale=3):
with gr.Group():
# Image tagger: fills `prompt` from an uploaded image (wired to predict_tags_wd / predict_tags_fl2_flux below).
with gr.Accordion("Generate Prompt from Image", open=False):
tagger_image = gr.Image(label="Input image", type="pil", sources=["upload", "clipboard"], height=256)
with gr.Accordion(label="Advanced options", open=False):
tagger_general_threshold = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.01, interactive=True)
tagger_character_threshold = gr.Slider(label="Character threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.01, interactive=True)
neg_prompt = gr.Text(label="Negative Prompt", lines=1, max_lines=8, placeholder="", visible=False)
# Hidden components used as inputs/outputs of the tagger event chain.
v2_character = gr.Textbox(label="Character", placeholder="hatsune miku", scale=2, visible=False)
v2_series = gr.Textbox(label="Series", placeholder="vocaloid", scale=2, visible=False)
v2_copy = gr.Button(value="Copy to clipboard", size="sm", interactive=False, visible=False)
tagger_algorithms = gr.CheckboxGroup(["Use WD Tagger", "Use Florence-2-Flux"], label="Algorithms", value=["Use WD Tagger"])
tagger_generate_from_image = gr.Button(value="Generate Prompt from Image")
prompt = gr.Textbox(label="Prompt", lines=1, max_lines=8, placeholder="Type a prompt", show_copy_button=True)
with gr.Row():
prompt_enhance = gr.Button(value="Enhance your prompt", variant="secondary")
auto_trans = gr.Checkbox(label="Auto translate to English", value=False, elem_classes="info")
with gr.Column(scale=1, elem_id="gen_column"):
generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
with gr.Row():
with gr.Column():
selected_info = gr.Markdown("")
# Gallery of the curated LoRAs loaded from loras.json.
gallery = gr.Gallery(
[(item["image"], item["title"]) for item in loras],
label="LoRA Gallery",
allow_preview=False,
columns=3,
elem_id="gallery"
)
with gr.Group():
custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="multimodalart/vintage-ads-flux")
gr.Markdown("[Check the list of FLUX LoRas](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
custom_lora_info = gr.HTML(visible=False)
custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
deselect_lora_button = gr.Button("Deselect LoRA", variant="secondary")
with gr.Column():
# HTML progress bar updated by run_lora while previews stream in.
progress_bar = gr.Markdown(elem_id="progress",visible=False)
result = gr.Image(label="Generated Image", format="png", show_share_button=False)
with gr.Group():
model_name = gr.Dropdown(label="Base Model", info="You can enter a huggingface model repo_id to want to use.", choices=models, value=models[0], allow_custom_value=True)
model_info = gr.Markdown(elem_classes="info")
with gr.Row():
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
# When an input image is given, run_lora takes the image-to-image path.
input_image = gr.Image(label="Input image", type="filepath", height=256, sources=["upload", "clipboard"], show_share_button=False)
with gr.Column():
image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
input_image_preprocess = gr.Checkbox(True, label="Preprocess Input image")
with gr.Column():
with gr.Row():
lora_scale = gr.Slider(label="LoRA Scale", minimum=-3, maximum=3, step=0.01, value=0.95)
width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
with gr.Row():
cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
with gr.Row():
randomize_seed = gr.Checkbox(True, label="Randomize seed")
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
disable_model_cache = gr.Checkbox(False, label="Disable model caching")
with gr.Accordion("External LoRA", open=True):
with gr.Column():
# One dict per external-LoRA slot; kept in sync by compose_lora_json and passed to run_lora.
lora_repo_json = gr.JSON(value=[{}] * num_loras, visible=False)
# Per-slot component lists, filled in by the loop below (one entry per slot).
lora_repo = [None] * num_loras
lora_weights = [None] * num_loras
lora_trigger = [None] * num_loras
lora_wt = [None] * num_loras
lora_info = [None] * num_loras
lora_copy = [None] * num_loras
lora_md = [None] * num_loras
lora_num = [None] * num_loras
with gr.Row():
for i in range(num_loras):
with gr.Column():
lora_repo[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Repo", choices=get_all_lora_tupled_list(), info="Input LoRA Repo ID", value="", allow_custom_value=True)
with gr.Row():
lora_weights[i] = gr.Dropdown(label=f"LoRA {int(i+1)} Filename", choices=[], info="Optional", value="", allow_custom_value=True)
lora_trigger[i] = gr.Textbox(label=f"LoRA {int(i+1)} Trigger Prompt", lines=1, max_lines=4, value="")
lora_wt[i] = gr.Slider(label=f"LoRA {int(i+1)} Scale", minimum=-3, maximum=3, step=0.01, value=1.00)
with gr.Row():
lora_info[i] = gr.Textbox(label="", info="Example of prompt:", value="", show_copy_button=True, interactive=False, visible=False)
lora_copy[i] = gr.Button(value="Copy example to prompt", visible=False)
lora_md[i] = gr.Markdown(value="", visible=False)
# Hidden slot index passed to compose_lora_json.
lora_num[i] = gr.Number(i, visible=False)
# Civitai search and download of LoRAs into a chosen slot.
with gr.Accordion("From URL", open=True, visible=True):
with gr.Row():
lora_search_civitai_basemodel = gr.CheckboxGroup(label="Search LoRA for", choices=["Flux.1 D", "Flux.1 S"], value=["Flux.1 D", "Flux.1 S"])
lora_search_civitai_sort = gr.Radio(label="Sort", choices=["Highest Rated", "Most Downloaded", "Newest"], value="Highest Rated")
lora_search_civitai_period = gr.Radio(label="Period", choices=["AllTime", "Year", "Month", "Week", "Day"], value="AllTime")
with gr.Row():
lora_search_civitai_query = gr.Textbox(label="Query", placeholder="flux", lines=1)
lora_search_civitai_tag = gr.Textbox(label="Tag", lines=1)
lora_search_civitai_submit = gr.Button("Search on Civitai")
with gr.Row():
lora_search_civitai_json = gr.JSON(value={}, visible=False)
lora_search_civitai_desc = gr.Markdown(value="", visible=False)
lora_search_civitai_result = gr.Dropdown(label="Search Results", choices=[("", "")], value="", allow_custom_value=True, visible=False)
lora_download_url = gr.Textbox(label="LoRA URL", placeholder="https://civitai.com/api/download/models/28907", lines=1)
with gr.Row():
lora_download = [None] * num_loras
for i in range(num_loras):
lora_download[i] = gr.Button(f"Get and set LoRA to {int(i+1)}")
# ControlNet Union slots (num_cns of them); cn_on selects the ControlNet pipelines.
with gr.Accordion("ControlNet (extremely slow)", open=True, visible=True):
with gr.Column():
cn_on = gr.Checkbox(False, label="Use ControlNet")
cn_mode = [None] * num_cns
cn_scale = [None] * num_cns
cn_image = [None] * num_cns
cn_image_ref = [None] * num_cns
cn_res = [None] * num_cns
cn_num = [None] * num_cns
with gr.Row():
for i in range(num_cns):
with gr.Column():
cn_mode[i] = gr.Radio(label=f"ControlNet {int(i+1)} Mode", choices=get_control_union_mode(), value=get_control_union_mode()[0])
with gr.Row():
cn_scale[i] = gr.Slider(label=f"ControlNet {int(i+1)} Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.75)
cn_res[i] = gr.Slider(label=f"ControlNet {int(i+1)} Preprocess resolution", minimum=128, maximum=512, value=384, step=1)
cn_num[i] = gr.Number(i, visible=False)
with gr.Row():
# Reference image in; the preprocessed control image (read-only) out.
cn_image_ref[i] = gr.Image(label="Image Reference", type="pil", format="png", height=256, sources=["upload", "clipboard"], show_share_button=False)
cn_image[i] = gr.Image(label="Control Image", type="pil", format="png", height=256, show_share_button=False, interactive=False)
# Gallery click -> update prompt placeholder, selection label, selected index and size.
gallery.select(
update_selection,
inputs=[width, height],
outputs=[prompt, selected_info, selected_index, width, height],
queue=False,
show_api=False,
trigger_mode="once",
)
# Typing a custom LoRA path resolves it and adds it to the gallery list.
custom_lora.input(
add_custom_lora,
inputs=[custom_lora],
outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt],
queue=False,
show_api=False,
)
custom_lora_button.click(
remove_custom_lora,
outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora],
queue=False,
show_api=False,
)
# Generate: (re)load the base model first; run_lora only runs if that step succeeded.
gr.on(
triggers=[generate_button.click, prompt.submit],
fn=change_base_model,
inputs=[model_name, cn_on, disable_model_cache],
outputs=[result],
queue=True,
show_api=False,
trigger_mode="once",
).success(
fn=run_lora,
inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
lora_scale, lora_repo_json, cn_on, auto_trans],
outputs=[result, seed, progress_bar],
queue=True,
show_api=True,
)
input_image.upload(preprocess_i2i_image, [input_image, input_image_preprocess, height, width], [input_image], queue=False, show_api=False)
deselect_lora_button.click(deselect_lora, None, [prompt, selected_info, selected_index, width, height], queue=False, show_api=False)
# Changing the base model or the ControlNet toggle shows model info, then reloads the pipelines.
gr.on(
triggers=[model_name.change, cn_on.change],
fn=get_t2i_model_info,
inputs=[model_name],
outputs=[model_info],
queue=False,
show_api=False,
trigger_mode="once",
).then(change_base_model, [model_name, cn_on, disable_model_cache], [result], queue=True, show_api=False)
prompt_enhance.click(enhance_prompt, [prompt], [prompt], queue=False, show_api=False)
# Civitai LoRA search.
gr.on(
triggers=[lora_search_civitai_submit.click, lora_search_civitai_query.submit, lora_search_civitai_tag.submit],
fn=search_civitai_lora,
inputs=[lora_search_civitai_query, lora_search_civitai_basemodel, lora_search_civitai_sort, lora_search_civitai_period, lora_search_civitai_tag],
outputs=[lora_search_civitai_result, lora_search_civitai_desc, lora_search_civitai_submit, lora_search_civitai_query],
scroll_to_output=True,
queue=True,
show_api=False,
)
lora_search_civitai_json.change(search_civitai_lora_json, [lora_search_civitai_query, lora_search_civitai_basemodel], [lora_search_civitai_json], queue=True, show_api=True) # fn for api
lora_search_civitai_result.change(select_civitai_lora, [lora_search_civitai_result], [lora_download_url, lora_search_civitai_desc], scroll_to_output=True, queue=False, show_api=False)
# Per-slot external-LoRA wiring.
for i, l in enumerate(lora_repo):
# Deselect also clears every slot's repo and resets its scale to 1.0.
deselect_lora_button.click(lambda: ("", 1.0), None, [lora_repo[i], lora_wt[i]], queue=False, show_api=False)
gr.on(
triggers=[lora_download[i].click],
fn=download_my_lora,
inputs=[lora_download_url, lora_repo[i]],
outputs=[lora_repo[i]],
scroll_to_output=True,
queue=True,
show_api=False,
)
# Slot repo/scale change -> update_loras -> list .safetensors files -> trigger prompt -> rebuild lora_repo_json.
gr.on(
triggers=[lora_repo[i].change, lora_wt[i].change],
fn=update_loras,
inputs=[prompt, lora_repo[i], lora_wt[i]],
outputs=[prompt, lora_repo[i], lora_wt[i], lora_info[i], lora_md[i]],
queue=False,
trigger_mode="once",
show_api=False,
).success(get_repo_safetensors, [lora_repo[i]], [lora_weights[i]], queue=False, show_api=False
).success(apply_lora_prompt, [lora_info[i]], [lora_trigger[i]], queue=False, show_api=False
).success(compose_lora_json, [lora_repo_json, lora_num[i], lora_repo[i], lora_wt[i], lora_weights[i], lora_trigger[i]], [lora_repo_json], queue=False, show_api=False)
# Per-slot ControlNet wiring: mode/weight change updates the control settings (and cn_on),
# then recomputes the control image; uploading a reference image also recomputes it.
for i, m in enumerate(cn_mode):
gr.on(
triggers=[cn_mode[i].change, cn_scale[i].change],
fn=set_control_union_mode,
inputs=[cn_num[i], cn_mode[i], cn_scale[i]],
outputs=[cn_on],
queue=True,
show_api=False,
).success(set_control_union_image, [cn_num[i], cn_mode[i], cn_image_ref[i], height, width, cn_res[i]], [cn_image[i]], queue=False, show_api=False)
cn_image_ref[i].upload(set_control_union_image, [cn_num[i], cn_mode[i], cn_image_ref[i], height, width, cn_res[i]], [cn_image[i]], queue=False, show_api=False)
# Tagger: clear fields, run WD tagger, then Florence-2-Flux, then compose the final prompt.
tagger_generate_from_image.click(lambda: ("", "", ""), None, [v2_series, v2_character, prompt], queue=False, show_api=False,
).success(
predict_tags_wd,
[tagger_image, prompt, tagger_algorithms, tagger_general_threshold, tagger_character_threshold],
[v2_series, v2_character, prompt, v2_copy],
show_api=False,
).success(predict_tags_fl2_flux, [tagger_image, prompt, tagger_algorithms], [prompt], show_api=False,
).success(compose_prompt_to_copy, [v2_character, v2_series, prompt], [prompt], queue=False, show_api=False)
# Second tab: rule-based prompt generator, Florence captioning and LLM prompt rewriting.
with gr.Tab("FLUX Prompt Generator"):
from prompt import (PromptGenerator, HuggingFaceInferenceNode, florence_caption,
ARTFORM, PHOTO_TYPE, ROLES, HAIRSTYLES, LIGHTING, COMPOSITION, POSE, BACKGROUND,
PHOTOGRAPHY_STYLES, DEVICE, PHOTOGRAPHER, ARTIST, DIGITAL_ARTFORM, PLACE,
FEMALE_DEFAULT_TAGS, MALE_DEFAULT_TAGS, FEMALE_BODY_TYPES, MALE_BODY_TYPES,
FEMALE_CLOTHING, MALE_CLOTHING, FEMALE_ADDITIONAL_DETAILS, MALE_ADDITIONAL_DETAILS, pg_title)
prompt_generator = PromptGenerator()
huggingface_node = HuggingFaceInferenceNode()
gr.HTML(pg_title)
with gr.Row():
with gr.Column(scale=2):
with gr.Accordion("Basic Settings"):
pg_custom = gr.Textbox(label="Custom Input Prompt (optional)")
pg_subject = gr.Textbox(label="Subject (optional)")
pg_gender = gr.Radio(["female", "male"], label="Gender", value="female")
# Add the radio button for global option selection
pg_global_option = gr.Radio(
["Disabled", "Random", "No Figure Rand"],
label="Set all options to:",
value="Disabled"
)
# Every option dropdown offers "disabled" and "random" ahead of its concrete choices.
with gr.Accordion("Artform and Photo Type", open=False):
pg_artform = gr.Dropdown(["disabled", "random"] + ARTFORM, label="Artform", value="disabled")
pg_photo_type = gr.Dropdown(["disabled", "random"] + PHOTO_TYPE, label="Photo Type", value="disabled")
with gr.Accordion("Character Details", open=False):
pg_body_types = gr.Dropdown(["disabled", "random"] + FEMALE_BODY_TYPES + MALE_BODY_TYPES, label="Body Types", value="disabled")
pg_default_tags = gr.Dropdown(["disabled", "random"] + FEMALE_DEFAULT_TAGS + MALE_DEFAULT_TAGS, label="Default Tags", value="disabled")
pg_roles = gr.Dropdown(["disabled", "random"] + ROLES, label="Roles", value="disabled")
pg_hairstyles = gr.Dropdown(["disabled", "random"] + HAIRSTYLES, label="Hairstyles", value="disabled")
pg_clothing = gr.Dropdown(["disabled", "random"] + FEMALE_CLOTHING + MALE_CLOTHING, label="Clothing", value="disabled")
with gr.Accordion("Scene Details", open=False):
pg_place = gr.Dropdown(["disabled", "random"] + PLACE, label="Place", value="disabled")
pg_lighting = gr.Dropdown(["disabled", "random"] + LIGHTING, label="Lighting", value="disabled")
pg_composition = gr.Dropdown(["disabled", "random"] + COMPOSITION, label="Composition", value="disabled")
pg_pose = gr.Dropdown(["disabled", "random"] + POSE, label="Pose", value="disabled")
pg_background = gr.Dropdown(["disabled", "random"] + BACKGROUND, label="Background", value="disabled")
with gr.Accordion("Style and Artist", open=False):
pg_additional_details = gr.Dropdown(["disabled", "random"] + FEMALE_ADDITIONAL_DETAILS + MALE_ADDITIONAL_DETAILS, label="Additional Details", value="disabled")
pg_photography_styles = gr.Dropdown(["disabled", "random"] + PHOTOGRAPHY_STYLES, label="Photography Styles", value="disabled")
pg_device = gr.Dropdown(["disabled", "random"] + DEVICE, label="Device", value="disabled")
pg_photographer = gr.Dropdown(["disabled", "random"] + PHOTOGRAPHER, label="Photographer", value="disabled")
pg_artist = gr.Dropdown(["disabled", "random"] + ARTIST, label="Artist", value="disabled")
pg_digital_artform = gr.Dropdown(["disabled", "random"] + DIGITAL_ARTFORM, label="Digital Artform", value="disabled")
pg_generate_button = gr.Button("Generate Prompt")
with gr.Column(scale=2):
with gr.Accordion("Image and Caption", open=False):
pg_input_image = gr.Image(label="Input Image (optional)")
pg_caption_output = gr.Textbox(label="Generated Caption", lines=3)
pg_create_caption_button = gr.Button("Create Caption")
pg_add_caption_button = gr.Button("Add Caption to Prompt")
with gr.Accordion("Prompt Generation", open=True):
pg_output = gr.Textbox(label="Generated Prompt / Input Text", lines=4)
pg_t5xxl_output = gr.Textbox(label="T5XXL Output", visible=True)
pg_clip_l_output = gr.Textbox(label="CLIP L Output", visible=True)
pg_clip_g_output = gr.Textbox(label="CLIP G Output", visible=True)
with gr.Column(scale=2):
# Inputs for huggingface_node.generate (wired below).
with gr.Accordion("Prompt Generation with LLM", open=False):
pg_happy_talk = gr.Checkbox(label="Happy Talk", value=True)
pg_compress = gr.Checkbox(label="Compress", value=True)
pg_compression_level = gr.Radio(["soft", "medium", "hard"], label="Compression Level", value="hard")
pg_poster = gr.Checkbox(label="Poster", value=False)
pg_custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)
pg_generate_text_button = gr.Button("Generate Prompt with LLM (Llama 3.1 70B)")
pg_text_output = gr.Textbox(label="Generated Text", lines=10)
def create_caption(image):
    """Return a Florence caption for ``image``, or an empty string when no image is given."""
    if image is None:
        return ""
    return florence_caption(image)
# "Create Caption" -> caption the optional input image into the caption textbox.
pg_create_caption_button.click(
create_caption,
inputs=[pg_input_image],
outputs=[pg_caption_output]
)
def generate_prompt_with_dynamic_seed(*args):
    """Call ``prompt_generator.generate_prompt`` with a fresh random seed in [0, 1000000].

    All positional arguments are forwarded after the seed. Returns a list whose first
    element is the seed used, followed by the items of ``generate_prompt``'s result.
    """
    fresh_seed = random.randint(0, 1000000)
    generated = prompt_generator.generate_prompt(fresh_seed, *args)
    return [fresh_seed, *generated]
# "Generate Prompt": outputs are (seed, prompt, <hidden number>, T5XXL, CLIP L, CLIP G),
# matching the list returned by generate_prompt_with_dynamic_seed.
pg_generate_button.click(
generate_prompt_with_dynamic_seed,
inputs=[pg_custom, pg_subject, pg_gender, pg_artform, pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles,
pg_additional_details, pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform,
pg_place, pg_lighting, pg_clothing, pg_composition, pg_pose, pg_background, pg_input_image],
outputs=[gr.Number(label="Used Seed", visible=False), pg_output, gr.Number(visible=False), pg_t5xxl_output, pg_clip_l_output, pg_clip_g_output]
) #
pg_add_caption_button.click(
prompt_generator.add_caption_to_prompt,
inputs=[pg_output, pg_caption_output],
outputs=[pg_output]
)
# Rewrite the generated prompt with the LLM node.
pg_generate_text_button.click(
huggingface_node.generate,
inputs=[pg_output, pg_happy_talk, pg_compress, pg_compression_level, pg_poster, pg_custom_base_prompt],
outputs=pg_text_output
)
def update_all_options(choice):
    """Set every prompt-generator option dropdown from the global option radio.

    "Disabled" sets all dropdowns to "disabled" and "Random" sets all to "random". Any
    other choice ("No Figure Rand") disables the figure-related dropdowns and randomizes
    the scene/style ones. Returns a ``{component: gr.update(value=...)}`` mapping.
    """
    every_dropdown = [
        pg_artform, pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles, pg_clothing,
        pg_place, pg_lighting, pg_composition, pg_pose, pg_background, pg_additional_details,
        pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform
    ]
    if choice in ("Disabled", "Random"):
        target_value = "disabled" if choice == "Disabled" else "random"
        return {dropdown: gr.update(value=target_value) for dropdown in every_dropdown}
    # No Figure Random
    figure_dropdowns = [pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles, pg_clothing, pg_pose, pg_additional_details]
    scene_dropdowns = [pg_artform, pg_place, pg_lighting, pg_composition, pg_background, pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform]
    updates = {dropdown: gr.update(value="disabled") for dropdown in figure_dropdowns}
    updates.update({dropdown: gr.update(value="random") for dropdown in scene_dropdowns})
    return updates
# The global option radio rewrites all option dropdowns at once.
pg_global_option.change(
update_all_options,
inputs=[pg_global_option],
outputs=[
pg_artform, pg_photo_type, pg_body_types, pg_default_tags, pg_roles, pg_hairstyles, pg_clothing,
pg_place, pg_lighting, pg_composition, pg_pose, pg_background, pg_additional_details,
pg_photography_styles, pg_device, pg_photographer, pg_artist, pg_digital_artform
]
)
# Page footer components.
description_ui()
gr.LoginButton()
gr.DuplicateButton(value="Duplicate Space for private use (This demo does not work on CPU. Requires GPU Space)")
# Enable the request queue and start the server.
app.queue()
app.launch()