Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,249 +1,182 @@
|
|
1 |
import gradio as gr
|
2 |
-
from
|
3 |
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
4 |
-
from huggingface_hub import hf_hub_download, CommitScheduler
|
5 |
-
from safetensors.torch import load_file
|
6 |
-
from share_btn import community_icon_html, loading_icon_html, share_js
|
7 |
-
from uuid import uuid4
|
8 |
-
from pathlib import Path
|
9 |
-
from PIL import Image
|
10 |
import torch
|
11 |
-
import json
|
12 |
-
import random
|
13 |
import copy
|
14 |
-
import
|
15 |
-
import pickle
|
16 |
import spaces
|
|
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
IMAGE_DATASET_DIR = Path("image_dataset") / f"train-{uuid4()}"
|
21 |
-
IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
|
22 |
-
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl"
|
23 |
-
|
24 |
-
scheduler = CommitScheduler(
|
25 |
-
repo_id="multimodalart/lora-fusing-preferences",
|
26 |
-
repo_type="dataset",
|
27 |
-
folder_path=IMAGE_DATASET_DIR,
|
28 |
-
path_in_repo=IMAGE_DATASET_DIR.name,
|
29 |
-
every=10
|
30 |
-
)
|
31 |
-
|
32 |
-
with open(lora_list, "r") as file:
|
33 |
-
data = json.load(file)
|
34 |
-
sdxl_loras = [
|
35 |
-
{
|
36 |
-
"image": item["image"] if item["image"].startswith("https://") else f'https://huggingface.co/spaces/multimodalart/LoraTheExplorer/resolve/main/{item["image"]}',
|
37 |
-
"title": item["title"],
|
38 |
-
"repo": item["repo"],
|
39 |
-
"trigger_word": item["trigger_word"],
|
40 |
-
"weights": item["weights"],
|
41 |
-
"is_compatible": item["is_compatible"],
|
42 |
-
"is_pivotal": item.get("is_pivotal", False),
|
43 |
-
"text_embedding_weights": item.get("text_embedding_weights", None),
|
44 |
-
"is_nc": item.get("is_nc", False)
|
45 |
-
}
|
46 |
-
for item in data
|
47 |
-
]
|
48 |
-
|
49 |
-
state_dicts = {}
|
50 |
-
|
51 |
-
for item in sdxl_loras:
|
52 |
-
saved_name = hf_hub_download(item["repo"], item["weights"])
|
53 |
-
|
54 |
-
if not saved_name.endswith('.safetensors'):
|
55 |
-
state_dict = torch.load(saved_name, map_location=torch.device('cpu'))
|
56 |
-
else:
|
57 |
-
state_dict = load_file(saved_name, device="cpu")
|
58 |
-
|
59 |
-
state_dicts[item["repo"]] = {
|
60 |
-
"saved_name": saved_name,
|
61 |
-
"state_dict": state_dict
|
62 |
-
}
|
63 |
-
|
64 |
-
css = '''
|
65 |
-
.gradio-container{max-width: 650px! important}
|
66 |
-
#title{text-align:center;}
|
67 |
-
#title h1{font-size: 250%}
|
68 |
-
.selected_random img{object-fit: cover}
|
69 |
-
.selected_random [data-testid="block-label"] span{display: none}
|
70 |
-
.plus_column{align-self: center}
|
71 |
-
.plus_button{font-size: 235% !important; text-align: center;margin-bottom: 19px}
|
72 |
-
#prompt{padding: 0 0 1em 0}
|
73 |
-
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
|
74 |
-
#run_button{position: absolute;margin-top: 25.8px;right: 0;margin-right: 0.75em;border-bottom-left-radius: 0px;border-top-left-radius: 0px}
|
75 |
-
.random_column{align-self: center; align-items: center;gap: 0.5em !important}
|
76 |
-
#share-btn-container{padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;margin-top: 0.35em;}
|
77 |
-
div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
|
78 |
-
#share-btn-container:hover {background-color: #060606}
|
79 |
-
#share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;font-size: 15px;}
|
80 |
-
#share-btn * {all: unset}
|
81 |
-
#share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
|
82 |
-
#share-btn-container .wrap {display: none !important}
|
83 |
-
#share-btn-container.hidden {display: none!important}
|
84 |
-
#post_gen_info{margin-top: .5em}
|
85 |
-
#thumbs_up_clicked{background:green}
|
86 |
-
#thumbs_down_clicked{background:red}
|
87 |
-
.title_lora a{color: var(--body-text-color) !important; opacity:0.6}
|
88 |
-
#prompt_area .form{border:0}
|
89 |
-
#reroll_button{position: absolute;right: 0;flex-grow: 1;min-width: 75px;padding: .1em}
|
90 |
-
.pending .min {min-height: auto}
|
91 |
-
'''
|
92 |
|
93 |
original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
94 |
|
95 |
@spaces.GPU
|
96 |
-
def
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
seed = random.randint(0, 2147483647)
|
135 |
-
generator = torch.Generator(device="cuda").manual_seed(seed)
|
136 |
-
image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768, generator=generator).images[0]
|
137 |
-
return image, gr.update(visible=True), seed, gr.update(visible=True, interactive=True), gr.update(visible=False), gr.update(visible=True, interactive=True), gr.update(visible=False)
|
138 |
-
|
139 |
-
def get_description(item):
    """Build the trigger-word hint for a LoRA entry.

    Args:
        item: Mapping describing a LoRA; only its ``"trigger_word"`` entry
            is read. A missing entry is treated like an empty trigger word.

    Returns:
        A ``(description, trigger_word)`` tuple. ``description`` is the text
        to display; ``trigger_word`` is the raw value found in ``item``
        (``None`` when the key is absent).
    """
    trigger_word = item.get("trigger_word")
    if trigger_word:
        description = f"Trigger: `{trigger_word}`"
    else:
        description = "No trigger, applied automatically"
    # Explicit tuple: the conditional applies to the description only.
    return (description, trigger_word)
|
142 |
-
|
143 |
-
def truncate_string(s, max_length=29):
    """Shorten ``s`` to at most ``max_length`` characters.

    Strings that already fit are returned unchanged. Longer strings are cut
    and suffixed with ``"..."`` so the result is exactly ``max_length``
    characters long.

    Args:
        s: The string to shorten.
        max_length: Maximum length of the returned string.

    Returns:
        ``s`` itself, or its truncated form ending in an ellipsis.
    """
    if len(s) <= max_length:
        return s
    if max_length <= 3:
        # No room for text before the ellipsis. Slicing s with a negative
        # bound would keep almost the whole string, so return as many dots
        # as fit instead.
        return "..."[:max(max_length, 0)]
    return s[:max_length - 3] + "..."
|
145 |
-
|
146 |
-
def shuffle_images():
    """Pick two random compatible LoRAs and build the UI updates for them.

    Reads the module-level ``sdxl_loras`` list and keeps only entries whose
    ``'is_compatible'`` value is truthy.

    Returns:
        A 12-tuple, in this order: link markdown, image update, description
        update and repo-id update for the first LoRA; the same four values
        for the second LoRA; the prompt update; the list of the two selected
        items; and the scale update twice (one per LoRA).

    Raises:
        IndexError: If fewer than two compatible LoRAs are available.
    """
    compatible_items = [item for item in sdxl_loras if item['is_compatible']]
    if len(compatible_items) < 2:
        # Same exception type as indexing the too-short list, but with a
        # message that says what is actually wrong.
        raise IndexError("at least two compatible LoRAs are required")
    two_shuffled_items = random.sample(compatible_items, 2)
    first, second = two_shuffled_items

    title_1 = gr.update(label=first['title'], value=first['image'])
    title_2 = gr.update(label=second['title'], value=second['image'])
    repo_id_1 = gr.update(value=first['repo'])
    repo_id_2 = gr.update(value=second['repo'])
    description_1, trigger_word_1 = get_description(first)
    description_2, trigger_word_2 = get_description(second)

    lora_1_link = f"[{truncate_string(first['repo'])}](https://huggingface.co/{first['repo']}) ✨"
    lora_2_link = f"[{truncate_string(second['repo'])}](https://huggingface.co/{second['repo']}) ✨"
    prompt_description_1 = gr.update(value=description_1, visible=True)
    prompt_description_2 = gr.update(value=description_2, visible=True)
    # Only keep trigger words that are actually set, so an empty or None
    # trigger does not leak "None" or stray spaces into the prompt box.
    trigger_words = [word for word in (trigger_word_1, trigger_word_2) if word]
    prompt = gr.update(value=" ".join(trigger_words))
    scale = gr.update(value=0.7)

    return lora_1_link, title_1, prompt_description_1, repo_id_1, lora_2_link, title_2, prompt_description_2, repo_id_2, prompt, two_shuffled_items, scale, scale
|
165 |
-
|
166 |
-
def save_preferences(lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, generated_image, thumbs_direction, seed):
|
167 |
-
image_path = IMAGE_DATASET_DIR / f"{uuid4()}.png"
|
168 |
-
with scheduler.lock:
|
169 |
-
Image.fromarray(generated_image).save(image_path)
|
170 |
-
with IMAGE_JSONL_PATH.open("a") as f:
|
171 |
-
json.dump({"prompt": prompt, "file_name":image_path.name, "lora_1_id": lora_1_id, "lora_1_scale": float(lora_1_scale), "lora_2_id": lora_2_id, "lora_2_scale": float(lora_2_scale), "thumbs_direction": thumbs_direction, "seed": int(seed)}, f)
|
172 |
-
f.write("\n")
|
173 |
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
with gr.Row():
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from huggingface_hub import login
|
3 |
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import torch
|
|
|
|
|
5 |
import copy
|
6 |
+
import os
|
|
|
7 |
import spaces
|
8 |
+
import random
|
9 |
|
10 |
+
hf_token = os.environ.get("HF_TOKEN")
|
11 |
+
login(token = hf_token)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
14 |
|
15 |
@spaces.GPU
|
16 |
+
def infer(lora_1_id, lora_1_sfts, lora_2_id, lora_2_sfts, prompt, negative_prompt, lora_1_scale, lora_2_scale, seed):
    """Generate one image after fusing two LoRAs into a copy of the base pipeline.

    Args:
        lora_1_id: Repository id of the first LoRA.
        lora_1_sfts: Weight file name to load from the first repository.
        lora_2_id: Repository id of the second LoRA.
        lora_2_sfts: Weight file name to load from the second repository.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt; an empty string is treated as None.
        lora_1_scale: Scale used when fusing the first LoRA.
        lora_2_scale: Scale used when fusing the second LoRA.
        seed: Seed for the CUDA generator; a negative value requests a
            randomly drawn seed.

    Returns:
        A ``(image, seed)`` tuple: the first image produced by the pipeline
        and the seed that was actually used.
    """
    # Work on deep copies of the trainable parts; presumably so that fusing
    # LoRA weights never alters the shared `original_pipe` between requests.
    unet = copy.deepcopy(original_pipe.unet)
    text_encoder = copy.deepcopy(original_pipe.text_encoder)
    text_encoder_2 = copy.deepcopy(original_pipe.text_encoder_2)

    pipe = StableDiffusionXLPipeline(
        vae=original_pipe.vae,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        scheduler=original_pipe.scheduler,
        tokenizer=original_pipe.tokenizer,
        tokenizer_2=original_pipe.tokenizer_2,
        unet=unet,
    )

    pipe.to("cuda")

    pipe.load_lora_weights(
        lora_1_id,
        weight_name=lora_1_sfts,
        low_cpu_mem_usage=True,
        use_auth_token=True,
    )
    # Scale passed by keyword. NOTE(review): in some diffusers releases the
    # first positional parameter of fuse_lora is `fuse_unet`, not the scale;
    # confirm against the installed version.
    pipe.fuse_lora(lora_scale=lora_1_scale)

    pipe.load_lora_weights(
        lora_2_id,
        weight_name=lora_2_sfts,
        low_cpu_mem_usage=True,
        use_auth_token=True,
    )
    pipe.fuse_lora(lora_scale=lora_2_scale)

    if negative_prompt == "":
        negative_prompt = None

    # The UI slider may deliver the seed as a float; manual_seed needs an int.
    seed = int(seed)
    if seed < 0:
        # Same upper bound as the seed slider, so the value can be re-entered.
        seed = random.randint(0, 423538377342)

    generator = torch.Generator(device="cuda").manual_seed(seed)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=25,
        width=1024,
        height=1024,
        generator=generator,
    ).images[0]

    return image, seed
|
70 |
+
|
71 |
+
# NOTE(review): the original indentation of this section was lost; the nesting
# below is reconstructed from the statement order and should be confirmed.
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):

        title = gr.HTML(
            '''
            <h1 style="text-align: center;">LoRA Fusion</h1>
            <p style="text-align: center;">Fuse 2 custom LoRa models</p>
            '''
        )

        # PART 1 • MODELS
        with gr.Row():

            with gr.Column():

                lora_1_id = gr.Textbox(
                    label="LoRa 1 ID",
                    placeholder="username/model_id"
                )

                lora_1_sfts = gr.Textbox(
                    label="Safetensors file",
                    placeholder="specific_chosen.safetensors"
                )

            with gr.Column():

                lora_2_id = gr.Textbox(
                    label="LoRa 2 ID",
                    placeholder="username/model_id"
                )

                lora_2_sfts = gr.Textbox(
                    label="Safetensors file",
                    placeholder="specific_chosen.safetensors"
                )

        # PART 2 • INFERENCE
        with gr.Row():

            prompt = gr.Textbox(
                label="Your prompt",
                info="Use your trigger words into a coherent prompt",
                placeholder="e.g: a triggerWordOne portrait in triggerWord2 style"
            )

            run_btn = gr.Button("Run")

        output_image = gr.Image(
            label="Output"
        )

        # Advanced Settings
        with gr.Accordion("Advanced Settings", open=False):

            with gr.Row():

                lora_1_scale = gr.Slider(
                    label="LoRa 1 scale",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.7
                )

                lora_2_scale = gr.Slider(
                    label="LoRa 2 scale",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.7
                )

            negative_prompt = gr.Textbox(
                label="Negative prompt"
            )

            seed = gr.Slider(
                label="Seed",
                info="-1 denotes a random seed",
                minimum=-1,
                maximum=423538377342,
                # Explicit integer step so any whole-number seed can be chosen.
                # NOTE(review): without it Gradio appears to derive the step
                # from the slider range — confirm.
                step=1,
                value=-1
            )

            last_used_seed = gr.Number(
                label="Last used seed",
                info="the seed used in the last generation",
            )

    # ACTIONS
    run_btn.click(
        fn=infer,
        inputs=[
            lora_1_id,
            lora_1_sfts,
            lora_2_id,
            lora_2_sfts,
            prompt,
            negative_prompt,
            lora_1_scale,
            lora_2_scale,
            seed
        ],
        outputs=[
            output_image,
            last_used_seed
        ]
    )

demo.queue(concurrency_count=2).launch()
|
182 |
+
|