import gradio as gr
import json
import logging
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline
import copy
import random
import time
# Load LoRAs from JSON file
with open('loras.json', 'r') as f:
    loras = json.load(f)
# Initialize the base model
models = ["camenduru/FLUX.1-dev-diffusers", "black-forest-labs/FLUX.1-schnell",
"sayakpaul/FLUX.1-merged", "John6666/blue-pencil-flux1-v001-fp8-flux"]
base_model = models[0]
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
MAX_SEED = 2**32-1
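# Context manager that reports how long a named block took (used for simple timing logs below)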
class calculateDuration:
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
def update_selection(evt: gr.SelectData, width, height):
    selected_lora = loras[evt.index]
    new_placeholder = f"Type a prompt for {selected_lora['title']}"
    lora_repo = selected_lora["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
    if "aspect" in selected_lora:
        if selected_lora["aspect"] == "portrait":
            width = 768
            height = 1024
        elif selected_lora["aspect"] == "landscape":
            width = 1024
            height = 768
    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        evt.index,
        width,
        height,
    )
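# Runs on ZeroGPU hardware; the decorator requests a GPU for up to `duration` seconds per call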
@spaces.GPU(duration=70)
def generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress):
    pipe.to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(seed)
    with calculateDuration("Generating image"):
        # Generate image
        image = pipe(
            prompt=f"{prompt} {trigger_word}",
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
    return image
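# Main handler: resolve which LoRA to use (gallery pick or manually entered repo), load its weights,
# generate the image on the GPU, then move the pipeline back to CPU and unload the LoRA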
def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
             lora_scale, lora_repo, lora_weights, lora_trigger, progress=gr.Progress(track_tqdm=True)):
    if selected_index is None and not lora_repo:
        raise gr.Error("You must select a LoRA before proceeding.")
    if selected_index is not None and not lora_repo:
        selected_lora = loras[selected_index]
        lora_path = selected_lora["repo"]
        trigger_word = selected_lora["trigger_word"]
    else:  # override with the manually entered repo and trigger
        selected_lora = loras[0]
        lora_path = lora_repo
        trigger_word = lora_trigger

    # Load LoRA weights
    with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
        if lora_weights:  # override with the manually chosen filename
            pipe.load_lora_weights(lora_path, weight_name=lora_weights)
        elif "weights" in selected_lora:
            pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
        else:
            pipe.load_lora_weights(lora_path)

    # Randomize the seed if requested
    with calculateDuration("Randomizing seed"):
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

    image = generate_image(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, progress)
    pipe.to("cpu")
    pipe.unload_lora_weights()
    return image, seed
run_lora.zerogpu = True
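# List the .safetensors files of a Hub repo to populate the "LoRA Filename" dropdown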
def get_repo_safetensors(repo_id: str):
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        if " " in repo_id or not api.repo_exists(repo_id):
            return gr.update(value="", choices=[])
        files = api.list_repo_files(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info.")
        print(e)
        return gr.update(choices=[])
    files = [f for f in files if f.endswith(".safetensors")]
    if len(files) == 0:
        return gr.update(value="", choices=[])
    else:
        return gr.update(value=files[0], choices=files)
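# Replace the global pipeline when a different base model repo is entered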
def change_base_model(repo_id: str):
    from huggingface_hub import HfApi
    global pipe
    api = HfApi()
    try:
        if " " in repo_id or not api.repo_exists(repo_id):
            return
        pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
    except Exception as e:
        print(e)
css = '''
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
'''
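# Build the Gradio UI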
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer</h1>""",
        elem_id="title",
    )
    selected_index = gr.State(None)
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
    with gr.Row():
        with gr.Column(scale=3):
            selected_info = gr.Markdown("")
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery"
            )
        with gr.Column(scale=4):
            result = gr.Image(label="Generated Image")
    with gr.Row():
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Column():
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                with gr.Row():
                    lora_repo = gr.Dropdown(label="LoRA Repo", choices=[], info="Input LoRA Repo ID", value="", allow_custom_value=True)
                    lora_weights = gr.Dropdown(label="LoRA Filename", choices=[], info="Optional", value="", allow_custom_value=True)
                    lora_trigger = gr.Textbox(label="LoRA Trigger Prompt", value="")
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)
                with gr.Row():
                    model_name = gr.Dropdown(label="Base Model", choices=models, value=models[0], allow_custom_value=True)
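    # Wire up the UI events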
    gallery.select(
        update_selection,
        inputs=[width, height],
        outputs=[prompt, selected_info, selected_index, width, height]
    )
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height,
                lora_scale, lora_repo, lora_weights, lora_trigger],
        outputs=[result, seed]
    )
    lora_repo.change(get_repo_safetensors, [lora_repo], [lora_weights])
    model_name.change(change_base_model, [model_name], None)
app.queue()
app.launch()