fffiloni committed
Commit
ba844d5
1 Parent(s): 69ac06c

Update app.py

Files changed (1)
  1. app.py +169 -236
app.py CHANGED
@@ -1,249 +1,182 @@
1
  import gradio as gr
2
- from time import sleep, time
3
  from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
4
- from huggingface_hub import hf_hub_download, CommitScheduler
5
- from safetensors.torch import load_file
6
- from share_btn import community_icon_html, loading_icon_html, share_js
7
- from uuid import uuid4
8
- from pathlib import Path
9
- from PIL import Image
10
  import torch
11
- import json
12
- import random
13
  import copy
14
- import gc
15
- import pickle
16
  import spaces
 
17
 
18
- lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
19
-
20
- IMAGE_DATASET_DIR = Path("image_dataset") / f"train-{uuid4()}"
21
- IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
22
- IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl"
23
-
24
- scheduler = CommitScheduler(
25
- repo_id="multimodalart/lora-fusing-preferences",
26
- repo_type="dataset",
27
- folder_path=IMAGE_DATASET_DIR,
28
- path_in_repo=IMAGE_DATASET_DIR.name,
29
- every=10
30
- )
31
-
32
- with open(lora_list, "r") as file:
33
- data = json.load(file)
34
- sdxl_loras = [
35
- {
36
- "image": item["image"] if item["image"].startswith("https://") else f'https://huggingface.co/spaces/multimodalart/LoraTheExplorer/resolve/main/{item["image"]}',
37
- "title": item["title"],
38
- "repo": item["repo"],
39
- "trigger_word": item["trigger_word"],
40
- "weights": item["weights"],
41
- "is_compatible": item["is_compatible"],
42
- "is_pivotal": item.get("is_pivotal", False),
43
- "text_embedding_weights": item.get("text_embedding_weights", None),
44
- "is_nc": item.get("is_nc", False)
45
- }
46
- for item in data
47
- ]
48
-
49
- state_dicts = {}
50
-
51
- for item in sdxl_loras:
52
- saved_name = hf_hub_download(item["repo"], item["weights"])
53
-
54
- if not saved_name.endswith('.safetensors'):
55
- state_dict = torch.load(saved_name, map_location=torch.device('cpu'))
56
- else:
57
- state_dict = load_file(saved_name, device="cpu")
58
-
59
- state_dicts[item["repo"]] = {
60
- "saved_name": saved_name,
61
- "state_dict": state_dict
62
- }
63
-
64
- css = '''
65
- .gradio-container{max-width: 650px! important}
66
- #title{text-align:center;}
67
- #title h1{font-size: 250%}
68
- .selected_random img{object-fit: cover}
69
- .selected_random [data-testid="block-label"] span{display: none}
70
- .plus_column{align-self: center}
71
- .plus_button{font-size: 235% !important; text-align: center;margin-bottom: 19px}
72
- #prompt{padding: 0 0 1em 0}
73
- #prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
74
- #run_button{position: absolute;margin-top: 25.8px;right: 0;margin-right: 0.75em;border-bottom-left-radius: 0px;border-top-left-radius: 0px}
75
- .random_column{align-self: center; align-items: center;gap: 0.5em !important}
76
- #share-btn-container{padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;margin-top: 0.35em;}
77
- div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
78
- #share-btn-container:hover {background-color: #060606}
79
- #share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;font-size: 15px;}
80
- #share-btn * {all: unset}
81
- #share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
82
- #share-btn-container .wrap {display: none !important}
83
- #share-btn-container.hidden {display: none!important}
84
- #post_gen_info{margin-top: .5em}
85
- #thumbs_up_clicked{background:green}
86
- #thumbs_down_clicked{background:red}
87
- .title_lora a{color: var(--body-text-color) !important; opacity:0.6}
88
- #prompt_area .form{border:0}
89
- #reroll_button{position: absolute;right: 0;flex-grow: 1;min-width: 75px;padding: .1em}
90
- .pending .min {min-height: auto}
91
- '''
92
 
93
  original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
94
 
95
  @spaces.GPU
96
- def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, seed=-1):
97
-
98
- repo_id_1 = shuffled_items[0]['repo']
99
- repo_id_2 = shuffled_items[1]['repo']
100
- print("Loading state dicts...")
101
- start_time = time()
102
- state_dict_1 = copy.deepcopy(state_dicts[repo_id_1]["state_dict"])
103
- state_dict_1 = {k: v.to(device="cuda", dtype=torch.float16) for k,v in state_dict_1.items() if torch.is_tensor(v)}
104
- state_dict_2 = copy.deepcopy(state_dicts[repo_id_2]["state_dict"])
105
- state_dict_2 = {k: v.to(device="cuda", dtype=torch.float16) for k,v in state_dict_2.items() if torch.is_tensor(v)}
106
- state_dict_time = time() - start_time
107
- print(f"State Dict time: {state_dict_time}")
108
- start_time = time()
109
- unet = copy.deepcopy(original_pipe.unet)
110
- text_encoder=copy.deepcopy(original_pipe.text_encoder)
111
- text_encoder_2=copy.deepcopy(original_pipe.text_encoder_2)
112
- pipe = StableDiffusionXLPipeline(vae=original_pipe.vae,
113
- text_encoder=text_encoder,
114
- text_encoder_2=text_encoder_2,
115
- scheduler=original_pipe.scheduler,
116
- tokenizer=original_pipe.tokenizer,
117
- tokenizer_2=original_pipe.tokenizer_2,
118
- unet=unet)
119
- pickle_time = time() - start_time
120
- print(f"copy time: {pickle_time}")
121
- pipe.to("cuda")
122
- start_time = time()
123
- print("Loading LoRA weights...")
124
- pipe.load_lora_weights(state_dict_1, low_cpu_mem_usage=True)
125
- pipe.fuse_lora(lora_1_scale)
126
- pipe.load_lora_weights(state_dict_2, low_cpu_mem_usage=True)
127
- pipe.fuse_lora(lora_2_scale)
128
- lora_time = time() - start_time
129
- print(f"Loaded LoRAs time: {lora_time}")
130
- if negative_prompt == "":
131
- negative_prompt = None
132
-
133
- if(seed < 0):
134
- seed = random.randint(0, 2147483647)
135
- generator = torch.Generator(device="cuda").manual_seed(seed)
136
- image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768, generator=generator).images[0]
137
- return image, gr.update(visible=True), seed, gr.update(visible=True, interactive=True), gr.update(visible=False), gr.update(visible=True, interactive=True), gr.update(visible=False)
138
-
139
- def get_description(item):
140
- trigger_word = item["trigger_word"]
141
- return f"Trigger: `{trigger_word}`" if trigger_word else "No trigger, applied automatically", trigger_word
142
-
143
- def truncate_string(s, max_length=29):
144
- return s[:max_length - 3] + "..." if len(s) > max_length else s
145
-
146
- def shuffle_images():
147
- compatible_items = [item for item in sdxl_loras if item['is_compatible']]
148
- random.shuffle(compatible_items)
149
- two_shuffled_items = compatible_items[:2]
150
- title_1 = gr.update(label=two_shuffled_items[0]['title'], value=two_shuffled_items[0]['image'])
151
- title_2 = gr.update(label=two_shuffled_items[1]['title'], value=two_shuffled_items[1]['image'])
152
- repo_id_1 = gr.update(value=two_shuffled_items[0]['repo'])
153
- repo_id_2 = gr.update(value=two_shuffled_items[1]['repo'])
154
- description_1, trigger_word_1 = get_description(two_shuffled_items[0])
155
- description_2, trigger_word_2 = get_description(two_shuffled_items[1])
156
-
157
- lora_1_link = f"[{truncate_string(two_shuffled_items[0]['repo'])}](https://huggingface.co/{two_shuffled_items[0]['repo']}) ✨"
158
- lora_2_link = f"[{truncate_string(two_shuffled_items[1]['repo'])}](https://huggingface.co/{two_shuffled_items[1]['repo']}) ✨"
159
- prompt_description_1 = gr.update(value=description_1, visible=True)
160
- prompt_description_2 = gr.update(value=description_2, visible=True)
161
- prompt = gr.update(value=f"{trigger_word_1} {trigger_word_2}")
162
- scale = gr.update(value=0.7)
163
-
164
- return lora_1_link, title_1, prompt_description_1, repo_id_1, lora_2_link, title_2, prompt_description_2, repo_id_2, prompt, two_shuffled_items, scale, scale
165
-
166
- def save_preferences(lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, generated_image, thumbs_direction, seed):
167
- image_path = IMAGE_DATASET_DIR / f"{uuid4()}.png"
168
- with scheduler.lock:
169
- Image.fromarray(generated_image).save(image_path)
170
- with IMAGE_JSONL_PATH.open("a") as f:
171
- json.dump({"prompt": prompt, "file_name":image_path.name, "lora_1_id": lora_1_id, "lora_1_scale": float(lora_1_scale), "lora_2_id": lora_2_id, "lora_2_scale": float(lora_2_scale), "thumbs_direction": thumbs_direction, "seed": int(seed)}, f)
172
- f.write("\n")
173
 
174
- return gr.update(visible=False), gr.update(visible=True), gr.update(interactive=False)
175
-
176
- def hide_post_gen_info():
177
- return gr.update(visible=False)
178
-
179
- with gr.Blocks(css=css) as demo:
180
- shuffled_items = gr.State()
181
- title = gr.HTML(
182
- '''<h1>LoRA Roulette 🎰</h1>
183
- <p>This random LoRAs are loaded into SDXL, can you find a fun way to combine them? 🎨</p>
184
- ''',
185
- elem_id="title"
186
- )
187
- with gr.Column():
188
- with gr.Column(min_width=10, scale=16, elem_classes="plus_column"):
189
  with gr.Row():
190
- with gr.Column(min_width=10, scale=4, elem_classes="random_column"):
191
- lora_1_link = gr.Markdown(elem_classes="title_lora")
192
- lora_1 = gr.Image(interactive=False, height=150, elem_classes="selected_random", elem_id="randomLoRA_1", show_share_button=False, show_download_button=False)
193
- lora_1_id = gr.Textbox(visible=False, elem_id="random_lora_1_id")
194
- lora_1_prompt = gr.Markdown(visible=False)
195
- with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
196
- plus = gr.HTML("+", elem_classes="plus_button")
197
- with gr.Column(min_width=10, scale=4, elem_classes="random_column"):
198
- lora_2_link = gr.Markdown(elem_classes="title_lora")
199
- lora_2 = gr.Image(interactive=False, height=150, elem_classes="selected_random", elem_id="randomLoRA_2", show_share_button=False, show_download_button=False)
200
- lora_2_id = gr.Textbox(visible=False, elem_id="random_lora_2_id")
201
- lora_2_prompt = gr.Markdown(visible=False)
202
- with gr.Column(min_width=10, scale=2, elem_classes="plus_column"):
203
- equal = gr.HTML("=", elem_classes="plus_button")
204
- shuffle_button = gr.Button("🎲 reroll", elem_id="reroll_button")
205
- with gr.Column(min_width=10, scale=14):
206
- with gr.Box(elem_id="generate_area"):
207
- with gr.Row(elem_id="prompt_area"):
208
- prompt = gr.Textbox(label="Your prompt", info="Rearrange the trigger words into a coherent prompt", show_label=False, interactive=True, elem_id="prompt")
209
- run_btn = gr.Button("Run", elem_id="run_button")
210
- output_image = gr.Image(label="Output", height=460, elem_id="output_image", interactive=False)
211
- with gr.Column(visible=False, elem_id="post_gen_info") as post_gen_info:
212
- with gr.Row():
213
- with gr.Column(min_width=10):
214
- thumbs_up = gr.Button("👍", elem_id="thumbs_up_unclicked")
215
- thumbs_up_clicked = gr.Button("👍", elem_id="thumbs_up_clicked", interactive=False, visible=False)
216
- with gr.Column(min_width=10):
217
- thumbs_down = gr.Button("👎", elem_id="thumbs_down_unclicked")
218
- thumbs_down_clicked = gr.Button("👎", elem_id="thumbs_down_clicked", interactive=False, visible=False)
219
- with gr.Column(min_width=10):
220
- with gr.Group(elem_id="share-btn-container") as share_group:
221
- community_icon = gr.HTML(community_icon_html)
222
- loading_icon = gr.HTML(loading_icon_html)
223
- share_button = gr.Button("Share to community", elem_id="share-btn")
224
- gr.Markdown('<p style="font-size: 90%;opacity: 0.8;">Rating helps finding the most compatible LoRAs 🤗, results are shared annonymously <a href="https://huggingface.co/datasets/multimodalart/lora-fusing-preferences" target="_blank" rel="noopener noreferrer">here</a></p>')
225
- with gr.Accordion("Advanced settings", open=False):
226
- with gr.Row():
227
- lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
228
- lora_2_scale = gr.Slider(label="LoRa 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
229
- negative_prompt = gr.Textbox(label="Negative prompt")
230
- seed = gr.Slider(label="Seed", info="-1 denotes a random seed", minimum=-1, maximum=2147483647, value=-1)
231
- last_used_seed = gr.Number(label="Last used seed", info="The seed used in the last generation", minimum=0, maximum=2147483647, value=-1, interactive=False)
232
- gr.Markdown("Generate with intent in [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer), but remember: sometimes restrictions flourish creativity 🌸")
233
 
234
- demo.load(shuffle_images, inputs=[], outputs=[lora_1_link, lora_1, lora_1_prompt, lora_1_id, lora_2_link, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale], queue=False, show_progress="hidden")
235
- shuffle_button.click(shuffle_images, outputs=[lora_1_link, lora_1, lora_1_prompt, lora_1_id, lora_2_link, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale], queue=False, show_progress="hidden")
236
-
237
- run_btn.click(hide_post_gen_info, outputs=[post_gen_info], queue=False).then(merge_and_run,
238
- inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale, seed],
239
- outputs=[output_image, post_gen_info, last_used_seed, thumbs_up, thumbs_up_clicked, thumbs_down, thumbs_down_clicked])
240
- prompt.submit(hide_post_gen_info, outputs=[post_gen_info], queue=False).then(merge_and_run,
241
- inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale, seed],
242
- outputs=[output_image, post_gen_info, last_used_seed, thumbs_up, thumbs_up_clicked, thumbs_down, thumbs_down_clicked])
243
-
244
- thumbs_up.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("up"), last_used_seed], outputs=[thumbs_up, thumbs_up_clicked, thumbs_down])
245
- thumbs_down.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("down"), last_used_seed], outputs=[thumbs_down, thumbs_down_clicked, thumbs_up])
246
- share_button.click(None, [], [], _js=share_js)
247
-
248
- demo.queue(concurrency_count=2)
249
- demo.launch()
1
  import gradio as gr
2
+ from huggingface_hub import login
3
  from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
4
  import torch
 
 
5
  import copy
6
+ import os
 
7
  import spaces
8
+ import random
9
 
10
+ hf_token = os.environ.get("HF_TOKEN")
11
+ login(token = hf_token)
12
 
13
  original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
14
 
15
  @spaces.GPU
16
+ def infer(lora_1_id, lora_1_sfts, lora_2_id, lora_2_sfts, prompt, negative_prompt, lora_1_scale, lora_2_scale, seed):
17
+
18
+ unet = copy.deepcopy(original_pipe.unet)
19
+ text_encoder = copy.deepcopy(original_pipe.text_encoder)
20
+ text_encoder_2 = copy.deepcopy(original_pipe.text_encoder_2)
21
+
22
+ pipe = StableDiffusionXLPipeline(
23
+ vae = original_pipe.vae,
24
+ text_encoder = text_encoder,
25
+ text_encoder_2 = text_encoder_2,
26
+ scheduler = original_pipe.scheduler,
27
+ tokenizer = original_pipe.tokenizer,
28
+ tokenizer_2 = original_pipe.tokenizer_2,
29
+ unet = unet
30
+ )
31
+
32
+ pipe.to("cuda")
33
+
34
+ pipe.load_lora_weights(
35
+ lora_1_id,
36
+ weight_name = lora_1_sfts,
37
+ low_cpu_mem_usage = True,
38
+ use_auth_token = True
39
+ )
40
+
41
+ pipe.fuse_lora(lora_1_scale)
42
+
43
+ pipe.load_lora_weights(
44
+ lora_2_id,
45
+ weight_name = lora_2_sfts,
46
+ low_cpu_mem_usage = True,
47
+ use_auth_token = True
48
+ )
49
+
50
+ pipe.fuse_lora(lora_2_scale)
51
+
52
+ if negative_prompt == "" :
53
+ negative_prompt = None
54
 
55
+ if seed < 0 :
56
+ seed = random.randint(0, 423538377342)
57
+
58
+ generator = torch.Generator(device="cuda").manual_seed(int(seed))
59
+
60
+ image = pipe(
61
+ prompt = prompt,
62
+ negative_prompt = negative_prompt,
63
+ num_inference_steps = 25,
64
+ width = 1024,
65
+ height = 1024,
66
+ generator = generator
67
+ ).images[0]
68
+
69
+ return image, seed
70
+
71
+ with gr.Blocks() as demo:
72
+ with gr.Column(elem_id="col-container"):
73
+
74
+ title = gr.HTML(
75
+ '''
76
+ <h1 style="text-align: center;">LoRA Fusion</h1>
77
+ <p style="text-align: center;">Fuse 2 custom LoRA models</p>
78
+ '''
79
+ )
80
+
81
+ # PART 1 • MODELS
82
  with gr.Row():
83
+
84
+ with gr.Column():
85
+
86
+ lora_1_id = gr.Textbox(
87
+ label = "LoRA 1 ID",
88
+ placeholder = "username/model_id"
89
+ )
90
+
91
+ lora_1_sfts = gr.Textbox(
92
+ label = "Safetensors file",
93
+ placeholder = "specific_chosen.safetensors"
94
+ )
95
+
96
+ with gr.Column():
97
+
98
+ lora_2_id = gr.Textbox(
99
+ label = "LoRA 2 ID",
100
+ placeholder = "username/model_id"
101
+ )
102
+
103
+ lora_2_sfts = gr.Textbox(
104
+ label = "Safetensors file",
105
+ placeholder = "specific_chosen.safetensors"
106
+ )
107
+
108
+ # PART 2 • INFERENCE
109
+ with gr.Row():
110
+
111
+ prompt = gr.Textbox(
112
+ label = "Your prompt",
113
+ info = "Use your trigger words in a coherent prompt",
114
+ placeholder = "e.g. a triggerWordOne portrait in triggerWord2 style"
115
+ )
116
+
117
+ run_btn = gr.Button("Run")
118
+
119
+ output_image = gr.Image(
120
+ label = "Output"
121
+ )
122
+
123
+ # Advanced Settings
124
+ with gr.Accordion("Advanced Settings", open=False):
125
+
126
+ with gr.Row():
127
+
128
+ lora_1_scale = gr.Slider(
129
+ label = "LoRA 1 scale",
130
+ minimum = 0,
131
+ maximum = 1,
132
+ step = 0.1,
133
+ value = 0.7
134
+ )
135
+
136
+ lora_2_scale = gr.Slider(
137
+ label = "LoRA 2 scale",
138
+ minimum = 0,
139
+ maximum = 1,
140
+ step = 0.1,
141
+ value = 0.7
142
+ )
143
+
144
+ negative_prompt = gr.Textbox(
145
+ label = "Negative prompt"
146
+ )
147
+
148
+ seed = gr.Slider(
149
+ label = "Seed",
150
+ info = "-1 denotes a random seed",
151
+ minimum = -1,
152
+ maximum = 423538377342,
153
+ value = -1
154
+ )
155
+
156
+ last_used_seed = gr.Number(
157
+ label = "Last used seed",
158
+ info = "the seed used in the last generation",
159
+ )
160
 
161
+ # ACTIONS
162
+ run_btn.click(
163
+ fn = infer,
164
+ inputs = [
165
+ lora_1_id,
166
+ lora_1_sfts,
167
+ lora_2_id,
168
+ lora_2_sfts,
169
+ prompt,
170
+ negative_prompt,
171
+ lora_1_scale,
172
+ lora_2_scale,
173
+ seed
174
+ ],
175
+ outputs = [
176
+ output_image,
177
+ last_used_seed
178
+ ]
179
+ )
180
+
181
+ demo.queue(concurrency_count=2).launch()
182
+
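
For reference (not part of the commit), here is a minimal standalone sketch of the fusion flow the new infer() implements: load the SDXL base pipeline, fuse two LoRAs with load_lora_weights / fuse_lora, then generate one image. The repo IDs and .safetensors filenames below are placeholders, and it assumes a CUDA GPU and a diffusers version exposing the same LoRA API used above.

# sketch_fuse_two_loras.py — illustrative only; repo IDs and weight names are placeholders
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load and fuse the first LoRA at scale 0.7 (the app's default slider value).
pipe.load_lora_weights("username/lora-one", weight_name="lora_one.safetensors")
pipe.fuse_lora(lora_scale=0.7)

# Load and fuse the second LoRA on top of the first.
pipe.load_lora_weights("username/lora-two", weight_name="lora_two.safetensors")
pipe.fuse_lora(lora_scale=0.7)

# Same generation settings as the updated app: 25 steps, 1024x1024, seeded generator.
generator = torch.Generator(device="cuda").manual_seed(42)
image = pipe(
    prompt="a triggerWordOne portrait in triggerWord2 style",
    negative_prompt=None,
    num_inference_steps=25,
    width=1024,
    height=1024,
    generator=generator,
).images[0]
image.save("fused.png")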