adamelliotfields committed
Commit b7fd57e
Parent: c851587
Files changed (3):
  1. app.py +61 -45
  2. generate.py +83 -64
  3. requirements.txt +4 -3
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr

from generate import generate

- DEFAULT_NEGATIVE_PROMPT = "<bad_prompt>, ugly, unattractive, deformed, disfigured, mutated, blurry, distorted, noisy, grainy, glitch, worst quality"
+ DEFAULT_NEGATIVE_PROMPT = "<fast_negative>"

# base font stacks
MONO_FONTS = ["monospace"]
@@ -60,6 +60,7 @@ with gr.Blocks(
font=[gr.themes.GoogleFont("Inter"), *SANS_FONTS],
font_mono=[gr.themes.GoogleFont("Ubuntu Mono"), *MONO_FONTS],
).set(
+ layout_gap="8px",
block_shadow="0 0 #0000",
block_shadow_dark="0 0 #0000",
block_background_fill=gr.themes.colors.gray.c50,
@@ -67,28 +68,49 @@ with gr.Blocks(
),
) as demo:
gr.HTML(read_file("./partials/intro.html"))
- output_images = gr.Gallery(
- elem_classes=["gallery"],
- show_share_button=False,
- interactive=False,
- show_label=False,
- label="Output",
- format="png",
- columns=2,
- )
- prompt = gr.Textbox(
- placeholder="corgi, at the beach, cute, 8k",
- show_label=False,
- label="Prompt",
- value=None,
- lines=2,
- )
- generate_btn = gr.Button("Generate", variant="primary", elem_classes=[])
+
+ with gr.Group():
+ output_images = gr.Gallery(
+ elem_classes=["gallery"],
+ show_share_button=False,
+ interactive=False,
+ show_label=False,
+ label="Output",
+ format="png",
+ columns=2,
+ )
+ prompt = gr.Textbox(
+ placeholder="corgi, at the beach, cute, 8k",
+ show_label=False,
+ label="Prompt",
+ value=None,
+ lines=2,
+ )
+
+ with gr.Row():
+ generate_btn = gr.Button("Generate", variant="primary", scale=6, elem_classes=[])
+ random_btn = gr.Button(
+ elem_classes=["icon-button"],
+ variant="secondary",
+ elem_id="random",
+ min_width=0,
+ value="🎲",
+ scale=1,
+ )
+ clear_btn = gr.ClearButton(
+ elem_classes=["icon-button"],
+ components=[output_images],
+ variant="secondary",
+ elem_id="clear",
+ min_width=0,
+ value="🗑️",
+ scale=1,
+ )

with gr.Accordion(
elem_classes=["accordion"],
elem_id="menu",
- label="Menu",
+ label="Open menu",
open=False,
):
with gr.Tabs():
@@ -98,12 +120,12 @@ with gr.Blocks(
label="Negative Prompt",
value=DEFAULT_NEGATIVE_PROMPT,
placeholder="",
- lines=1,
+ lines=2,
)

with gr.Row():
num_images = gr.Dropdown(
- choices=[1, 2, 3, 4],
+ choices=list(range(1, 9)),
filterable=False,
label="Images",
value=1,
@@ -113,7 +135,7 @@ with gr.Blocks(
label="Width",
minimum=256,
maximum=1024,
- value=512,
+ value=448,
step=32,
scale=2,
)
@@ -121,7 +143,7 @@ with gr.Blocks(
label="Height",
minimum=256,
maximum=1024,
- value=512,
+ value=576,
step=32,
scale=2,
)
@@ -131,7 +153,7 @@ with gr.Blocks(
label="Guidance Scale",
minimum=1.0,
maximum=15.0,
- value=7.5,
+ value=7,
step=0.1,
)
inference_steps = gr.Slider(
@@ -171,7 +193,7 @@ with gr.Blocks(
"PNDM",
],
)
- seed = gr.Number(label="Seed", value=42)
+ seed = gr.Number(label="Seed", value=42, scale=1)

with gr.Row():
use_karras = gr.Checkbox(
@@ -184,39 +206,33 @@ with gr.Blocks(
elem_classes=["checkbox"],
label="Autoincrement",
value=True,
- scale=2,
- )
- random_seed_btn = gr.Button(
- "🎲 Random seed",
- variant="secondary",
- size="sm",
- scale=1,
+ scale=4,
)

with gr.TabItem("🛠️ Advanced"):
with gr.Group():
with gr.Row():
- deep_cache_interval = gr.Slider(
+ deepcache_interval = gr.Slider(
label="DeepCache Interval",
minimum=1,
maximum=4,
- value=0,
- step=1,
- )
- deep_cache_branch = gr.Slider(
- label="DeepCache Branch",
- minimum=0,
- maximum=3,
- value=0,
+ value=2,
step=1,
)
tgate_step = gr.Slider(
label="T-GATE Step",
minimum=0,
maximum=50,
- value=0,
+ value=20,
step=1,
)
+ tome_ratio = gr.Slider(
+ label="ToMe Ratio",
+ minimum=0.0,
+ maximum=1.0,
+ value=0.0,
+ step=0.01,
+ )

with gr.Row():
use_taesd = gr.Checkbox(
@@ -242,7 +258,7 @@ with gr.Blocks(
gr.Markdown(read_file("info.md"), elem_classes=["markdown"])

# update the random seed using JavaScript
- random_seed_btn.click(None, outputs=[seed], js="() => Math.floor(Math.random() * 2**32)")
+ random_btn.click(None, outputs=[seed], js="() => Math.floor(Math.random() * 2**32)")

# ensure correct argument order
generate_btn.click(
@@ -266,9 +282,9 @@ with gr.Blocks(
use_clip_skip,
truncate_prompts,
increment_seed,
- deep_cache_interval,
- deep_cache_branch,
+ deepcache_interval,
tgate_step,
+ tome_ratio,
],
)
 
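Note on the new seed wiring: the 🎲 button no longer invokes a Python handler at all; passing `js` to `Button.click` with `fn=None` runs the callback entirely in the browser and writes the result into the `seed` component. A minimal, self-contained sketch of the same pattern (the surrounding layout is illustrative, not the Space's actual app.py):

    import gradio as gr

    with gr.Blocks() as demo:
        # illustrative layout: a seed field plus a dice button, mirroring this commit
        seed = gr.Number(label="Seed", value=42)
        random_btn = gr.Button("🎲")
        # fn=None with js=... runs client-side and assigns the return value to `seed`
        random_btn.click(None, outputs=[seed], js="() => Math.floor(Math.random() * 2**32)")

    if __name__ == "__main__":
        demo.launch()

The new 🗑️ button works the same way on the server side: `gr.ClearButton(components=[output_images])` clears the gallery without any custom callback.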
generate.py CHANGED
@@ -7,6 +7,7 @@ from types import MethodType
from warnings import filterwarnings

import spaces
+ import tomesd
import torch
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
from DeepCache import DeepCacheSDHelper
@@ -54,27 +55,63 @@ class Loader:
cls._instance.pipe = None
return cls._instance

+ def _load_deepcache(self, interval=1):
+ has_deepcache = hasattr(self.pipe, "deepcache")
+
+ if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
+ return self.pipe.deepcache
+ if has_deepcache:
+ self.pipe.deepcache.disable()
+ else:
+ self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
+
+ self.pipe.deepcache.set_params(cache_interval=interval)
+ self.pipe.deepcache.enable()
+ return self.pipe.deepcache
+
+ def _load_tgate(self):
+ has_tgate = hasattr(self.pipe, "tgate")
+ has_deepcache = hasattr(self.pipe, "deepcache")
+
+ if not has_tgate:
+ self.pipe.tgate = MethodType(
+ tgate_sd_deepcache if has_deepcache else tgate_sd,
+ self.pipe,
+ )
+
+ return self.pipe.tgate
+
def _load_vae(self, model_name=None, taesd=False, dtype=None):
- if taesd:
+ vae_type = type(self.pipe.vae)
+ is_kl = issubclass(vae_type, (AutoencoderKL, OptimizedModule))
+ is_tiny = issubclass(vae_type, AutoencoderTiny)
+
+ # by default all models use KL
+ if is_kl and taesd:
# can't compile tiny VAE
- return AutoencoderTiny.from_pretrained(
+ print("Switching to Tiny VAE...")
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
pretrained_model_name_or_path="madebyollin/taesd",
use_safetensors=True,
torch_dtype=dtype,
).to(self.gpu)
+ return self.pipe.vae
+
+ if is_tiny and not taesd:
+ print("Switching to KL VAE...")
+ self.pipe.vae = torch.compile(
+ fullgraph=True,
+ mode="reduce-overhead",
+ model=AutoencoderKL.from_pretrained(
+ pretrained_model_name_or_path=model_name,
+ use_safetensors=True,
+ torch_dtype=dtype,
+ subfolder="vae",
+ ).to(self.gpu),
+ )
+ return self.pipe.vae

- return torch.compile(
- fullgraph=True,
- mode="reduce-overhead",
- model=AutoencoderKL.from_pretrained(
- pretrained_model_name_or_path=model_name,
- use_safetensors=True,
- torch_dtype=dtype,
- subfolder="vae",
- ).to(self.gpu),
- )
-
- def load(self, model, scheduler, karras, taesd, dtype=None):
+ def load(self, model, scheduler, karras, taesd, deepcache_interval, dtype=None):
model_lower = model.lower()

schedulers = {
@@ -126,13 +163,9 @@
if not same_scheduler or not same_karras:
self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)

- # if compiled will be an OptimizedModule
- vae_type = type(self.pipe.vae)
- if (issubclass(vae_type, (AutoencoderKL, OptimizedModule)) and taesd) or (
- issubclass(vae_type, AutoencoderTiny) and not taesd
- ):
- print(f"Switching to {'Tiny' if taesd else 'KL'} VAE...")
- self.pipe.vae = self._load_vae(model_lower, taesd, dtype)
+ self._load_vae(model_lower, taesd, dtype)
+ self._load_deepcache(interval=deepcache_interval)
+ self._load_tgate()
return self.pipe
else:
print(f"Unloading {model_name.lower()}...")
@@ -149,7 +182,9 @@

print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
self.pipe = StableDiffusionPipeline.from_pretrained(**pipe_kwargs).to(self.gpu)
- self.pipe.vae = self._load_vae(model_lower, taesd, dtype)
+ self._load_vae(model_lower, taesd, dtype)
+ self._load_deepcache(interval=deepcache_interval)
+ self._load_tgate()
self.pipe.load_textual_inversion(
pretrained_model_name_or_path=list(EMBEDDINGS.keys()),
tokens=list(EMBEDDINGS.values()),
@@ -157,26 +192,15 @@
return self.pipe


+ # applies tome to the pipeline
@contextmanager
- def deep_cache(pipe, interval=1, branch=0, tgate_step=0):
- if interval > 1:
- helper = DeepCacheSDHelper(pipe=pipe)
- helper.set_params(cache_interval=interval, cache_branch_id=branch)
- helper.enable()
-
- if tgate_step > 0:
- pipe.deepcache = helper
- pipe.tgate = MethodType(tgate_sd_deepcache, pipe)
-
- try:
- yield helper
- finally:
- helper.disable()
- elif interval < 2 and tgate_step > 0:
- pipe.tgate = MethodType(tgate_sd, pipe)
- yield None
- else:
- yield None
+ def token_merging(pipe, tome_ratio=0):
+ try:
+ if tome_ratio > 0:
+ tomesd.apply_patch(pipe, max_downsample=1, sx=2, sy=2, ratio=tome_ratio)
+ yield
+ finally:
+ tomesd.remove_patch(pipe) # idempotent


# parse prompts with arrays
@@ -194,7 +218,6 @@ def parse_prompt(prompt: str) -> list[str]:
current_prompt = prompt
for i, token in enumerate(combo):
current_prompt = current_prompt.replace(f"[[{arrays[i]}]]", token.strip(), 1)
-
prompts.append(current_prompt)
return prompts

@@ -216,9 +239,9 @@
clip_skip=False,
truncate_prompts=False,
increment_seed=True,
- deep_cache_interval=1,
- deep_cache_branch=0,
+ deepcache_interval=1,
tgate_step=0,
+ tome_ratio=0,
Error=Exception,
):
if not torch.cuda.is_available():
@@ -241,7 +264,7 @@

with torch.inference_mode():
loader = Loader()
- pipe = loader.load(model, scheduler, karras, taesd, dtype=TORCH_DTYPE)
+ pipe = loader.load(model, scheduler, karras, taesd, deepcache_interval, TORCH_DTYPE)

# prompt embeds
compel = Compel(
@@ -271,25 +294,21 @@
[pos_embeds, neg_embeds]
)

- with deep_cache(
- pipe,
- interval=deep_cache_interval,
- branch=deep_cache_branch,
- tgate_step=tgate_step,
- ):
- pipe_kwargs = {
- "num_inference_steps": inference_steps,
- "negative_prompt_embeds": neg_embeds,
- "guidance_scale": guidance_scale,
- "prompt_embeds": pos_embeds,
- "generator": generator,
- "height": height,
- "width": width,
- }
- result = (
- pipe.tgate(**pipe_kwargs, gate_step=tgate_step)
- if tgate_step > 0
- else pipe(**pipe_kwargs)
+ with token_merging(pipe, tome_ratio=tome_ratio):
+ # cap the tgate step
+ gate_step = min(
+ tgate_step if tgate_step > 0 else inference_steps,
+ inference_steps,
+ )
+ result = pipe.tgate(
+ num_inference_steps=inference_steps,
+ negative_prompt_embeds=neg_embeds,
+ guidance_scale=guidance_scale,
+ prompt_embeds=pos_embeds,
+ gate_step=gate_step,
+ generator=generator,
+ height=height,
+ width=width,
)
images.append((result.images[0], str(current_seed)))
 
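For readers unfamiliar with the optimization libraries touched here, a rough standalone sketch of how DeepCache and tomesd attach to a plain diffusers pipeline (the model id, ratio, and interval below are illustrative; the Space's Loader reuses a cached pipeline and helper rather than rebuilding them per request):

    import torch
    import tomesd
    from DeepCache import DeepCacheSDHelper
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # illustrative model id
        torch_dtype=torch.float16,
    ).to("cuda")

    # DeepCache: cache and reuse high-level UNet features every `cache_interval` steps
    helper = DeepCacheSDHelper(pipe=pipe)
    helper.set_params(cache_interval=2, cache_branch_id=0)
    helper.enable()

    # ToMe: merge redundant tokens before attention; ratio=0 would be a no-op
    tomesd.apply_patch(pipe, ratio=0.5)

    image = pipe("corgi, at the beach, cute, 8k", num_inference_steps=30).images[0]

    # both patches are reversible, which is what the token_merging context manager relies on
    tomesd.remove_patch(pipe)
    helper.disable()

T-GATE is left out of the sketch because the Space binds it as a method on the pipeline (`MethodType(tgate_sd_deepcache, pipe)`), and its entry points depend on the tgate version pinned in requirements.txt.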
requirements.txt CHANGED
@@ -1,13 +1,14 @@
accelerate
compel
- deepcache
+ deepcache==0.1.1
diffusers
hf-transfer
- gradio
+ gradio==4.39.0
ruff
scipy # for LMS scheduler
spaces
- tgate
+ tgate==0.1.2
+ tomesd==0.1.3
torch
torchvision
transformers
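Since deepcache, gradio, tgate, and tomesd are now pinned, a quick runtime check can confirm the installed versions match the file (a small sketch; distribution names are taken from the list above):

    from importlib.metadata import PackageNotFoundError, version

    # print the installed version of each newly pinned dependency
    for name in ("deepcache", "gradio", "tgate", "tomesd"):
        try:
            print(f"{name}=={version(name)}")
        except PackageNotFoundError:
            print(f"{name} is not installed")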