alvdansen commited on
Commit
974dfa6
β€’
1 Parent(s): aef31f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -130
app.py CHANGED
@@ -1,31 +1,17 @@
1
  import json
2
  import random
3
-
4
  import gradio as gr
5
  import numpy as np
6
  import spaces
7
  import torch
8
  from diffusers import DiffusionPipeline, LCMScheduler
 
9
 
10
  # Load the JSON data
11
  with open("sdxl_lora.json", "r") as file:
12
  data = json.load(file)
13
- sdxl_loras_raw = [
14
- {
15
- "image": item["image"],
16
- "title": item["title"],
17
- "repo": item["repo"],
18
- "trigger_word": item["trigger_word"],
19
- "weights": item["weights"],
20
- "is_pivotal": item.get("is_pivotal", False),
21
- "text_embedding_weights": item.get("text_embedding_weights", None),
22
- "likes": item.get("likes", 0),
23
- }
24
- for item in data
25
- ]
26
-
27
- # Sort the loras by likes
28
- sdxl_loras_raw = sorted(sdxl_loras_raw, key=lambda x: x["likes"], reverse=True)
29
 
30
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -43,9 +29,33 @@ def update_selection(selected_state: gr.SelectData, gr_sdxl_loras):
43
  return lora_id, trigger_word
44
 
45
  def load_lora_for_style(style_repo):
46
- pipe.unload_lora_weights() # Unload any previously loaded weights
47
  pipe.load_lora_weights(style_repo, adapter_name="lora")
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  @spaces.GPU
50
  def infer(
51
  pre_prompt,
@@ -59,7 +69,6 @@ def infer(
59
  user_lora_weight,
60
  progress=gr.Progress(track_tqdm=True),
61
  ):
62
- # Load the appropriate LoRA weights
63
  load_lora_for_style(user_lora_selector)
64
 
65
  if randomize_seed:
@@ -81,143 +90,131 @@ def infer(
81
  return image
82
 
83
  css = """
84
- h1 {
 
 
 
 
 
 
 
 
 
 
85
  text-align: center;
86
- display:block;
87
  }
88
- p {
89
- text-align: justify;
90
- display:block;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  }
92
  """
93
 
94
- if torch.cuda.is_available():
95
- power_device = "GPU"
96
- else:
97
- power_device = "CPU"
98
-
99
  with gr.Blocks(css=css) as demo:
100
  gr.Markdown(
101
- f"""
102
- # ⚑ FlashDiffusion: FlashLoRA ⚑
103
- This is an interactive demo of [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) **on top of** existing LoRAs.
104
-
105
- The distillation method proposed in [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) *by ClΓ©ment Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin* from Jasper Research.
106
- The LoRAs can be added **without** any retraining for similar results in most cases. Feel free to tweak the parameters and use your own LoRAs by giving a look at the [Github Repo](https://github.com/gojasper/flash-diffusion)
107
- """
108
- )
109
- gr.Markdown(
110
- "If you enjoy the space, please also promote *open-source* by giving a ⭐ to our repo [![GitHub Stars](https://img.shields.io/github/stars/gojasper/flash-diffusion?style=social)](https://github.com/gojasper/flash-diffusion)"
 
 
111
  )
112
 
113
- # Index of selected LoRA
114
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
115
- # Serve as memory for currently loaded lora in pipe
116
  gr_lora_id = gr.State(value="")
117
 
118
  with gr.Row():
119
- with gr.Blocks():
120
- with gr.Column():
121
- user_lora_selector = gr.Textbox(
122
- label="Current Selected LoRA",
123
- max_lines=1,
124
- interactive=False,
125
- )
126
-
127
- user_lora_weight = gr.Slider(
128
- label="Selected LoRA Weight",
129
- minimum=0.5,
130
- maximum=3,
131
- step=0.1,
132
- value=1,
133
- )
 
 
 
 
 
 
 
 
134
 
135
- gallery = gr.Gallery(
136
- value=[(item["image"], item["title"]) for item in sdxl_loras_raw],
137
- label="SDXL LoRA Gallery",
138
- allow_preview=False,
139
- columns=3,
140
- elem_id="gallery",
141
- show_share_button=False,
142
- )
143
-
144
- with gr.Column():
145
  with gr.Row():
146
- prompt = gr.Text(
147
- label="Prompt",
148
- show_label=False,
149
- max_lines=1,
150
- placeholder="Enter your prompt",
151
- container=False,
152
- scale=5,
153
- )
154
 
155
- run_button = gr.Button("Run", scale=1)
156
-
157
- result = gr.Image(label="Result", show_label=False)
158
 
159
  with gr.Accordion("Advanced Settings", open=False):
160
- pre_prompt = gr.Text(
161
  label="Pre-Prompt",
162
- show_label=True,
163
- max_lines=1,
164
  placeholder="Pre Prompt from the LoRA config",
165
- container=True,
166
- scale=5,
167
  )
168
 
169
- seed = gr.Slider(
170
- label="Seed",
171
- minimum=0,
172
- maximum=MAX_SEED,
173
- step=1,
174
- value=0,
175
- )
176
-
177
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
178
-
179
  with gr.Row():
180
- num_inference_steps = gr.Slider(
181
- label="Number of inference steps",
182
- minimum=4,
183
- maximum=8,
184
  step=1,
185
- value=4,
186
  )
 
187
 
188
- with gr.Row():
189
- guidance_scale = gr.Slider(
190
- label="Guidance Scale",
191
- minimum=1,
192
- maximum=6,
193
- step=0.5,
194
- value=1,
195
- )
196
 
197
- hint_negative = gr.Markdown(
198
- """πŸ’‘ _Hint : Negative Prompt will only work with Guidance > 1 but the model was
199
- trained to be used with guidance = 1 (ie. without guidance).
200
- Can degrade the results, use cautiously._"""
 
 
201
  )
202
 
203
- negative_prompt = gr.Text(
204
  label="Negative Prompt",
205
- show_label=False,
206
- max_lines=1,
207
  placeholder="Enter a negative Prompt",
208
- container=False,
209
  )
210
 
211
  gr.on(
212
- [
213
- run_button.click,
214
- seed.change,
215
- randomize_seed.change,
216
- prompt.submit,
217
- negative_prompt.change,
218
- negative_prompt.submit,
219
- guidance_scale.change,
220
- ],
221
  fn=infer,
222
  inputs=[
223
  pre_prompt,
@@ -228,24 +225,30 @@ with gr.Blocks(css=css) as demo:
228
  negative_prompt,
229
  guidance_scale,
230
  user_lora_selector,
231
- user_lora_weight,
232
  ],
233
  outputs=[result],
234
  )
235
 
 
 
236
  gallery.select(
237
  fn=update_selection,
238
  inputs=[gr_sdxl_loras],
239
- outputs=[
240
- user_lora_selector,
241
- pre_prompt,
242
- ],
243
- show_progress="hidden",
244
  )
245
 
246
- gr.Markdown("**Disclaimer:**")
247
  gr.Markdown(
248
- "This demo is only for research purpose. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards."
 
 
 
 
 
 
 
 
 
249
  )
250
 
251
  demo.queue().launch()
 
1
  import json
2
  import random
3
+ import requests
4
  import gradio as gr
5
  import numpy as np
6
  import spaces
7
  import torch
8
  from diffusers import DiffusionPipeline, LCMScheduler
9
+ from PIL import Image
10
 
11
  # Load the JSON data
12
  with open("sdxl_lora.json", "r") as file:
13
  data = json.load(file)
14
+ sdxl_loras_raw = sorted(data, key=lambda x: x["likes"], reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
 
29
  return lora_id, trigger_word
30
 
31
  def load_lora_for_style(style_repo):
32
+ pipe.unload_lora_weights()
33
  pipe.load_lora_weights(style_repo, adapter_name="lora")
34
 
35
+ def get_image(image_data):
36
+ if isinstance(image_data, str):
37
+ return image_data
38
+
39
+ if isinstance(image_data, dict):
40
+ local_path = image_data.get('local_path')
41
+ hf_url = image_data.get('hf_url')
42
+ else:
43
+ return None # or a default image path
44
+
45
+ try:
46
+ return local_path # Return the local path string
47
+ except:
48
+ try:
49
+ response = requests.get(hf_url)
50
+ if response.status_code == 200:
51
+ with open(local_path, 'wb') as f:
52
+ f.write(response.content)
53
+ return local_path # Return the local path string
54
+ except Exception as e:
55
+ print(f"Failed to load image: {e}")
56
+
57
+ return None # or a default image path
58
+
59
  @spaces.GPU
60
  def infer(
61
  pre_prompt,
 
69
  user_lora_weight,
70
  progress=gr.Progress(track_tqdm=True),
71
  ):
 
72
  load_lora_for_style(user_lora_selector)
73
 
74
  if randomize_seed:
 
90
  return image
91
 
92
  css = """
93
+ body {
94
+ background-color: #1a1a1a;
95
+ color: #ffffff;
96
+ }
97
+ .container {
98
+ max-width: 900px;
99
+ margin: auto;
100
+ padding: 20px;
101
+ }
102
+ h1, h2 {
103
+ color: #4CAF50;
104
  text-align: center;
 
105
  }
106
+ .gallery {
107
+ display: flex;
108
+ flex-wrap: wrap;
109
+ justify-content: center;
110
+ }
111
+ .gallery img {
112
+ margin: 10px;
113
+ border-radius: 10px;
114
+ transition: transform 0.3s ease;
115
+ }
116
+ .gallery img:hover {
117
+ transform: scale(1.05);
118
+ }
119
+ .gradio-slider input[type="range"] {
120
+ background-color: #4CAF50;
121
+ }
122
+ .gradio-button {
123
+ background-color: #4CAF50 !important;
124
  }
125
  """
126
 
 
 
 
 
 
127
  with gr.Blocks(css=css) as demo:
128
  gr.Markdown(
129
+ """
130
+ # ⚑ FlashDiffusion: Araminta K's FlashLoRA Showcase ⚑
131
+
132
+ This interactive demo showcases [Araminta K's models](https://huggingface.co/alvdansen) using [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/) technology.
133
+
134
+ ## Acknowledgments
135
+ - Original Flash Diffusion technology by the Jasper AI team
136
+ - Based on the paper: [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) by ClΓ©ment Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin
137
+ - Models showcased here are created by Araminta K at Alvdansen Labs
138
+
139
+ Explore the power of FlashLoRA with Araminta K's unique artistic styles!
140
+ """
141
  )
142
 
 
143
  gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
 
144
  gr_lora_id = gr.State(value="")
145
 
146
  with gr.Row():
147
+ with gr.Column(scale=2):
148
+ gallery = gr.Gallery(
149
+ value=[(img, title) for img, title in
150
+ ((get_image(item["image"]), item["title"]) for item in sdxl_loras_raw)
151
+ if img is not None],
152
+ label="SDXL LoRA Gallery",
153
+ show_label=False,
154
+ elem_id="gallery",
155
+ columns=3,
156
+ height=600,
157
+ )
158
+
159
+ user_lora_selector = gr.Textbox(
160
+ label="Current Selected LoRA",
161
+ interactive=False,
162
+ )
163
+
164
+ with gr.Column(scale=3):
165
+ prompt = gr.Textbox(
166
+ label="Prompt",
167
+ placeholder="Enter your prompt",
168
+ lines=3,
169
+ )
170
 
 
 
 
 
 
 
 
 
 
 
171
  with gr.Row():
172
+ run_button = gr.Button("Run", variant="primary")
173
+ clear_button = gr.Button("Clear")
 
 
 
 
 
 
174
 
175
+ result = gr.Image(label="Result", height=512)
 
 
176
 
177
  with gr.Accordion("Advanced Settings", open=False):
178
+ pre_prompt = gr.Textbox(
179
  label="Pre-Prompt",
 
 
180
  placeholder="Pre Prompt from the LoRA config",
181
+ lines=2,
 
182
  )
183
 
 
 
 
 
 
 
 
 
 
 
184
  with gr.Row():
185
+ seed = gr.Slider(
186
+ label="Seed",
187
+ minimum=0,
188
+ maximum=MAX_SEED,
189
  step=1,
190
+ value=0,
191
  )
192
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
193
 
194
+ num_inference_steps = gr.Slider(
195
+ label="Number of inference steps",
196
+ minimum=4,
197
+ maximum=8,
198
+ step=1,
199
+ value=4,
200
+ )
 
201
 
202
+ guidance_scale = gr.Slider(
203
+ label="Guidance Scale",
204
+ minimum=1,
205
+ maximum=6,
206
+ step=0.5,
207
+ value=1,
208
  )
209
 
210
+ negative_prompt = gr.Textbox(
211
  label="Negative Prompt",
 
 
212
  placeholder="Enter a negative Prompt",
213
+ lines=2,
214
  )
215
 
216
  gr.on(
217
+ [run_button.click, prompt.submit],
 
 
 
 
 
 
 
 
218
  fn=infer,
219
  inputs=[
220
  pre_prompt,
 
225
  negative_prompt,
226
  guidance_scale,
227
  user_lora_selector,
228
+ gr.Slider(label="Selected LoRA Weight", minimum=0.5, maximum=3, step=0.1, value=1),
229
  ],
230
  outputs=[result],
231
  )
232
 
233
+ clear_button.click(lambda: "", outputs=[prompt, result])
234
+
235
  gallery.select(
236
  fn=update_selection,
237
  inputs=[gr_sdxl_loras],
238
+ outputs=[user_lora_selector, pre_prompt],
 
 
 
 
239
  )
240
 
 
241
  gr.Markdown(
242
+ """
243
+ ## Unleash Your Creativity!
244
+
245
+ This showcase brings together the speed of Flash Diffusion and the artistic flair of Araminta K's models.
246
+ Craft your prompts, adjust the settings, and watch as AI brings your ideas to life in stunning detail.
247
+
248
+ Remember to use this tool ethically and respect copyright and individual privacy.
249
+
250
+ Enjoy exploring these unique artistic styles!
251
+ """
252
  )
253
 
254
  demo.queue().launch()