Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,824 Bytes
05a3ba6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import spaces # Import this first to avoid CUDA initialization issues
import os
import gradio as gr
import json
import torch
import random
import time
from PIL import Image
from diffusers import DiffusionPipeline
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The Hugging Face access token is read from the 'waffles' environment
# variable and passed to from_pretrained below; fail fast at import time
# when it is missing.
hf_token = os.getenv('waffles')
if not hf_token:
    raise ValueError("Hugging Face API token not found. Please set the 'waffles' environment variable.")

# LoRA catalogue. Entries are accessed by the app via the keys "repo",
# "title", "image", "trigger_word" and (optionally) "weights".
# Explicit UTF-8 so titles/trigger words decode the same on every host
# instead of depending on the platform's locale encoding.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)

# Base FLUX.1-schnell pipeline in bfloat16, authenticated with the token
# above and moved to the selected device. Loaded once at import time and
# shared by every request.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    token=hf_token,
).to(device)

# Upper bound for seeds (random draws and the UI slider): 2**32 - 1.
MAX_SEED = 2**32 - 1
class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    After the block exits, the instance exposes ``start_time`` and
    ``end_time`` (wall-clock timestamps from ``time.time()``) and
    ``elapsed_time`` (seconds, measured with the monotonic
    ``time.perf_counter()`` clock).

    Args:
        activity_name: Optional label included in the printed message.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        # Wall-clock start kept for callers that read the timestamp; the
        # duration itself uses perf_counter, which is monotonic and
        # high-resolution, so it never goes negative or jumps when the
        # system clock is adjusted.
        self.start_time = time.time()
        self._perf_start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.elapsed_time = time.perf_counter() - self._perf_start
        self.end_time = time.time()
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        # Returning None (falsy) lets any exception raised in the body propagate.
        return None
@spaces.GPU(duration=90)
def generate_images(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, num_images, progress):
    """Generate ``num_images`` images with the shared FLUX pipeline.

    Decorated with ``spaces.GPU(duration=90)`` so it runs inside a ZeroGPU
    allocation on Hugging Face Spaces.

    Args:
        prompt: User prompt text.
        trigger_word: LoRA trigger word appended to the prompt; may be empty.
        steps: Number of inference steps.
        seed: Seed for a single torch generator shared by all images, so the
            images differ from each other but the whole batch is reproducible.
        cfg_scale: Forwarded to the pipeline as ``guidance_scale``.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: LoRA strength, forwarded via ``joint_attention_kwargs``.
        num_images: How many images to generate.
        progress: Accepted for interface compatibility; not used here.

    Returns:
        A list with the first image of each pipeline call.
    """
    # The generator lives on `device`, so the pipeline must be there too;
    # it may have been moved elsewhere between calls.
    pipe.to(device)
    # Numeric UI values may arrive as floats; range(), manual_seed() and the
    # pipeline's step/size arguments need ints.
    num_images = int(num_images)
    generator = torch.Generator(device=device).manual_seed(int(seed))
    # Avoid a dangling trailing space when the LoRA has no trigger word.
    full_prompt = f"{prompt} {trigger_word}" if trigger_word else prompt
    images = []
    with calculateDuration("Generating images"):
        for _ in range(num_images):
            image = pipe(
                prompt=full_prompt,
                num_inference_steps=int(steps),
                guidance_scale=cfg_scale,
                width=int(width),
                height=int(height),
                generator=generator,
                joint_attention_kwargs={"scale": lora_scale},
            ).images[0]
            images.append(image)
    return images
def run_lora(prompt, cfg_scale, steps, selected_repo, randomize_seed, seed, width, height, lora_scale, num_images, progress=gr.Progress(track_tqdm=True)):
    """Generate images for ``prompt`` with the LoRA selected in the gallery.

    Looks up ``selected_repo`` in ``loras``, loads its weights into the
    shared pipeline, generates the images, and always unloads the LoRA
    afterwards. Unloading happens even when loading or generation fails, so
    the adapter does not stay attached to the shared pipeline for later
    requests.

    Returns:
        ``(images, seed)``: the generated images and the seed actually used
        (a fresh random seed when ``randomize_seed`` is true).

    Raises:
        gr.Error: if no LoRA is selected or the repo is not in ``loras``.
    """
    if not selected_repo:
        raise gr.Error("You must select a LoRA before proceeding.")
    selected_lora = next((lora for lora in loras if lora["repo"] == selected_repo), None)
    if not selected_lora:
        raise gr.Error("Selected LoRA not found.")

    lora_path = selected_lora["repo"]
    # Not every catalogue entry necessarily has a trigger word or title.
    trigger_word = selected_lora.get("trigger_word", "")
    title = selected_lora.get("title", lora_path)

    try:
        with calculateDuration(f"Loading LoRA weights for {title}"):
            if "weights" in selected_lora:
                pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
            else:
                pipe.load_lora_weights(lora_path)

        with calculateDuration("Randomizing seed"):
            if randomize_seed:
                seed = random.randint(0, MAX_SEED)

        images = generate_images(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, num_images, progress)
    finally:
        # The pipeline is shared across requests: drop this LoRA no matter
        # how the request ended. The pipeline is intentionally NOT moved to
        # the CPU here; generate_images seeds a torch.Generator on `device`,
        # so the pipeline has to stay on that same device.
        pipe.unload_lora_weights()
    return images, seed
def update_selection(evt: gr.SelectData):
    """Gallery ``select`` handler: report which LoRA was clicked.

    Returns a pair: the Markdown label for the clicked entry, and that
    entry's repo id, which the UI keeps as the current selection.
    """
    chosen = loras[evt.index]
    title, repo = chosen["title"], chosen["repo"]
    return f"Selected LoRA: {title}", repo
# NOTE(review): this only sets a plain attribute on run_lora. Nothing in the
# visible code reads it, and the GPU allocation is requested by the
# @spaces.GPU decorator on generate_images. Confirm whether the `spaces`
# package uses this flag before relying on it or removing it.
run_lora.zerogpu = True
# Custom CSS passed to gr.Blocks(css=...). The selectors target the ids used
# in the UI: the "title" HTML, the "gallery" LoRA picker, the "gen_btn"
# Generate button, and the <div id="info_blob"> banner.
# NOTE(review): the gallery-item comment says "50%", but the rule sets fixed
# 50px thumbnails.
css = '''
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: auto; width: auto;}
#gallery .gallery-item{width: 50px; height: 50px; margin: 0px;} /* Make buttons 50% height and width */
#gallery img{width: 100%; height: 100%; object-fit: cover;} /* Resize images to fit buttons */
#info_blob {
background-color: #f0f0f0;
border: 2px solid #ccc;
padding: 10px;
margin: 10px 0;
text-align: center;
font-size: 1.2em;
font-weight: bold;
color: #333;
border-radius: 8px;
}
'''
# Gradio UI: prompt, a LoRA picker gallery, a results gallery and generation
# settings, wired to run_lora; launched with a request queue.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer</h1>""",
        elem_id="title",
    )
    # Info blob stating what the app is running
    info_blob = gr.HTML(
        """<div id="info_blob"> Activist, Futurist, and Realist LoRa-stocked Quick-Use Image Manufactory (over Flux Schnell)</div>"""
    )
    selected_lora_text = gr.Markdown("Selected LoRA: None")
    # Repo id of the LoRA picked in the gallery; "" until the user picks one.
    selected_repo = gr.State(value="")
    # Prompt takes the full line
    prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Type a prompt after selecting a LoRA", elem_id="full_line_prompt")

    with gr.Row():
        with gr.Column(scale=1):  # LoRA collection on the left
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery",
            )
        with gr.Column(scale=1):  # Generated images on the right
            result = gr.Gallery(label="Generated Images")
            # The seed actually used is reported through the "Seed" slider in
            # the settings below. A second component bound to the same `seed`
            # name was shadowed by that slider and never updated.

    with gr.Column():
        with gr.Row():
            cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=4)
        with gr.Row():
            width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
            height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
        with gr.Row():
            randomize_seed = gr.Checkbox(True, label="Randomize seed")
            # Seed input; also receives the seed run_lora actually used.
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)
            num_images = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1)

    # Clicking a gallery tile updates the label and the stored repo id.
    gallery.select(
        fn=update_selection,
        inputs=[],
        outputs=[selected_lora_text, selected_repo],
    )

    generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
    generate_button.click(
        run_lora,
        inputs=[prompt, cfg_scale, steps, selected_repo, randomize_seed, seed, width, height, lora_scale, num_images],
        outputs=[result, seed],
    )

app.queue()
app.launch()