Spaces:
Runtime error
Runtime error
File size: 4,125 Bytes
0e0ee20 7dc34c1 7039ded 0e0ee20 f3e96f9 c59400c 0e0ee20 d6802e8 f3e96f9 0e0ee20 701c4ad c59400c 0e0ee20 f3e96f9 0e0ee20 f3e96f9 0e0ee20 fd8e800 0e0ee20 fd8e800 0e0ee20 07d3eff 02302e4 07d3eff 0e0ee20 d6802e8 02302e4 07d3eff 0e0ee20 db98dea 1fff27d 0e0ee20 db98dea 0e0ee20 1fff27d db98dea 8dce9c7 0e0ee20 2c6d128 0e0ee20 07d3eff 0e0ee20 f3e96f9 0e0ee20 d6802e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import gradio as gr
import json
import logging
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline
import copy
# Load LoRAs from JSON file
with open('loras.json', 'r') as f:
loras = json.load(f)
# Initialize the base model
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
pipe.to("cuda")
MAX_SEED = 2**32-1
def update_selection(evt: gr.SelectData):
selected_lora = loras[evt.index]
new_placeholder = f"Type a prompt for {selected_lora['title']}"
lora_repo = selected_lora["repo"]
updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
return (
gr.update(placeholder=new_placeholder),
updated_text,
evt.index
)
@spaces.GPU(duration=90)
def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
if selected_index is None:
raise gr.Error("You must select a LoRA before proceeding.")
selected_lora = loras[selected_index]
lora_path = selected_lora["repo"]
trigger_word = selected_lora["trigger_word"]
# Load LoRA weights
if "weights" in selected_lora:
pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
else:
pipe.load_lora_weights(lora_path)
# Set random seed for reproducibility
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device="cuda").manual_seed(seed)
# Generate image
image = pipe(
prompt=f"{prompt} {trigger_word}",
num_inference_steps=steps,
guidance_scale=cfg_scale,
width=width,
height=height,
generator=generator,
joint_attention_kwargs={"scale": lora_scale},
).images[0]
yield image
pipe.unload_lora_weights()
'''
#gen_column{height: 100%}
#gen_btn{height: 100%}
'''
with gr.Blocks(theme=gr.themes.Soft()) as app:
gr.Markdown("# FLUX.1 LoRA the Explorer")
selected_index = gr.State(None)
with gr.Row():
with gr.Column(scale=3):
prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
with gr.Column(scale=1, elem_id="gen_column"):
generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
with gr.Row():
with gr.Column(scale=3):
selected_info = gr.Markdown("")
gallery = gr.Gallery(
[(item["image"], item["title"]) for item in loras],
label="LoRA Gallery",
allow_preview=False,
columns=3
)
with gr.Column(scale=4):
result = gr.Image(label="Generated Image")
with gr.Row():
with gr.Accordion("Advanced Settings", open=False):
with gr.Column():
with gr.Row():
cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=30)
with gr.Row():
width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
with gr.Row():
randomize_seed = gr.Checkbox(True, label="Randomize seed")
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.85)
gallery.select(update_selection, outputs=[prompt, selected_info, selected_index])
gr.on(
triggers=[generate_button.click, prompt.submit],
fn=run_lora,
inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
outputs=[result]
)
app.queue()
app.launch() |