File size: 6,824 Bytes
05a3ba6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import spaces  # Import this first to avoid CUDA initialization issues
import os
import gradio as gr
import json
import torch
import random
import time
from PIL import Image
from diffusers import DiffusionPipeline

# Use the GPU when torch reports one is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face access token, read from the (non-standard) 'waffles' environment variable.
hf_token = os.getenv('waffles')

# Fail fast at startup: the model download below is authenticated with this token.
if not hf_token:
    raise ValueError("Hugging Face API token not found. Please set the 'waffles' environment variable.")

# LoRA catalogue. Entries are read elsewhere in this file via the keys
# "repo", "trigger_word", "title", "image" and (optionally) "weights".
# JSON is UTF-8 by specification, so decode it explicitly rather than relying
# on the platform's default locale encoding.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)

# Initialize the base model with authentication and move it to the chosen device
# once at startup (bfloat16 weights).
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    token=hf_token,
).to(device)

# Upper bound for seeds (largest unsigned 32-bit value); used for the seed
# slider range and for random seed generation.
MAX_SEED = 2**32 - 1

class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    After the block exits, the duration in seconds is available as
    ``elapsed_time`` on the instance. Exceptions raised inside the block are
    not suppressed.

    Args:
        activity_name: Optional label included in the printed message.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock and can jump (even backwards) if the system clock is adjusted.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        # Falsy return value: let any exception from the body propagate.
        return False

@spaces.GPU(duration=90)
def generate_images(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, num_images, progress):
    """Run the shared FLUX pipeline ``num_images`` times and collect the results.

    Args:
        prompt: User prompt text.
        trigger_word: LoRA trigger word, appended to the prompt after a space.
        steps: Number of inference steps.
        seed: Seed for a single torch generator shared by all iterations, so the
            whole batch is reproducible for a given seed.
        cfg_scale: Passed to the pipeline as ``guidance_scale``.
        width, height: Output size in pixels.
        lora_scale: LoRA strength, forwarded via ``joint_attention_kwargs``.
        num_images: How many images to generate (one pipeline call each).
        progress: Accepted for interface compatibility; not used here.

    Returns:
        list: ``pipe(...).images[0]`` from each call, in generation order
        (presumably PIL images — confirm against the diffusers version in use).
    """
    # A caller may have moved the pipeline off `device` after a previous request.
    # The generator below lives on `device`, and diffusers refuses to sample CUDA
    # latents into a CPU pipeline (or vice versa), so make sure they agree.
    pipe.to(device)

    # UI widgets can deliver numeric values as floats; manual_seed() and range()
    # both require ints.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    full_prompt = f"{prompt} {trigger_word}"
    images = []

    with calculateDuration("Generating images"):
        for _ in range(int(num_images)):
            output = pipe(
                prompt=full_prompt,
                num_inference_steps=steps,
                guidance_scale=cfg_scale,
                width=width,
                height=height,
                generator=generator,
                joint_attention_kwargs={"scale": lora_scale},
            )
            images.append(output.images[0])
    return images

def run_lora(prompt, cfg_scale, steps, selected_repo, randomize_seed, seed, width, height, lora_scale, num_images, progress=gr.Progress(track_tqdm=True)):
    """Generate images with the selected LoRA temporarily loaded into the pipeline.

    Args:
        prompt: User prompt text.
        cfg_scale: Guidance scale.
        steps: Number of inference steps.
        selected_repo: The ``"repo"`` value of the chosen entry in ``loras``.
        randomize_seed: When truthy, ignore ``seed`` and draw a fresh one in
            ``[0, MAX_SEED]``.
        seed: Seed to use when ``randomize_seed`` is falsy.
        width, height: Output size in pixels.
        lora_scale: LoRA strength.
        num_images: How many images to generate.
        progress: Gradio progress tracker (tqdm tracking enabled), forwarded to
            ``generate_images``.

    Returns:
        tuple: ``(images, seed)`` — the generated images and the seed actually
        used, so the UI can display it.

    Raises:
        gr.Error: If no LoRA is selected or ``selected_repo`` is not in ``loras``.
    """
    if not selected_repo:
        raise gr.Error("You must select a LoRA before proceeding.")

    selected_lora = next((lora for lora in loras if lora["repo"] == selected_repo), None)
    if selected_lora is None:
        raise gr.Error("Selected LoRA not found.")

    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]

    try:
        # Load LoRA weights (optionally a specific weight file inside the repo).
        with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
            if "weights" in selected_lora:
                pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
            else:
                pipe.load_lora_weights(lora_path)

        # Set random seed for reproducibility
        with calculateDuration("Randomizing seed"):
            if randomize_seed:
                seed = random.randint(0, MAX_SEED)

        images = generate_images(prompt, trigger_word, steps, seed, cfg_scale, width, height, lora_scale, num_images, progress)
    finally:
        # `pipe` is shared across requests: always drop this request's LoRA,
        # even when loading or generation failed, so adapters never pile up.
        # The pipeline is deliberately left on its device; moving it to CPU
        # here would make the next request run on CPU against a CUDA generator.
        pipe.unload_lora_weights()

    return images, seed

def update_selection(evt: gr.SelectData):
    """Handle a click in the LoRA gallery.

    Looks up the clicked entry in ``loras`` by the event's index and returns a
    pair: the Markdown status line naming it, and its repo id (stored by the UI
    as the current selection).
    """
    chosen = loras[evt.index]
    status_line = f"Selected LoRA: {chosen['title']}"
    return status_line, chosen["repo"]

# Attribute flag on the handler; nothing in this file reads it. Presumably a
# hint for Hugging Face ZeroGPU hosting — confirm it is still required.
setattr(run_lora, "zerogpu", True)

# Custom stylesheet passed to gr.Blocks; selectors target the elem_id values
# given to components in the UI definition.
css = """
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: auto; width: auto;}
#gallery .gallery-item{width: 50px; height: 50px; margin: 0px;} /* Make buttons 50% height and width */
#gallery img{width: 100%; height: 100%; object-fit: cover;} /* Resize images to fit buttons */
#info_blob {
    background-color: #f0f0f0;
    border: 2px solid #ccc;
    padding: 10px;
    margin: 10px 0;
    text-align: center;
    font-size: 1.2em;
    font-weight: bold;
    color: #333;
    border-radius: 8px;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    # Page header: logo image plus the app name.
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/resolve/main/flux_lora.png" alt="LoRA"> FLUX LoRA the Explorer</h1>""",
        elem_id="title",
    )

    # Info blob stating what the app is running
    info_blob = gr.HTML(
        """<div id="info_blob"> Activist, Futurist, and Realist LoRa-stocked Quick-Use Image Manufactory (over Flux Schnell)</div>"""
    )

    # Status line plus per-session state holding the chosen LoRA's repo id.
    selected_lora_text = gr.Markdown("Selected LoRA: None")
    selected_repo = gr.State(value="")

    # Prompt takes the full line
    prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Type a prompt after selecting a LoRA", elem_id="full_line_prompt")

    with gr.Row():
        with gr.Column(scale=1):  # LoRA collection on the left
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Gallery",
                allow_preview=False,
                columns=3,
                elem_id="gallery"
            )
        with gr.Column(scale=1):  # Generated images on the right
            result = gr.Gallery(label="Generated Images")
            # Read-only display of the seed used for the last generation. It has
            # its own name because `seed` is the input slider defined below; it
            # is filled from that slider after each run (see the click chain).
            used_seed = gr.Number(label="Seed", value=0, interactive=False)

    with gr.Column():
        with gr.Row():
            cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=1)
            steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=4)

        with gr.Row():
            width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
            height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)

        with gr.Row():
            randomize_seed = gr.Checkbox(True, label="Randomize seed")
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.95)
            num_images = gr.Slider(label="Number of Images", minimum=1, maximum=4, step=1, value=1)

    # Clicking a gallery tile records the selection (handler receives gr.SelectData).
    gallery.select(
        fn=update_selection,
        inputs=[],
        outputs=[selected_lora_text, selected_repo]
    )

    generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
    generate_button.click(
        run_lora,
        inputs=[prompt, cfg_scale, steps, selected_repo, randomize_seed, seed, width, height, lora_scale, num_images],
        outputs=[result, seed]
    ).then(
        # run_lora writes the seed it actually used back into the slider; mirror
        # that value into the read-only display next to the results.
        fn=lambda used: used,
        inputs=[seed],
        outputs=[used_seed],
    )

# Enable Gradio's request queue on the app, then start the web server.
app.queue().launch()