Spaces:
Runtime error
Runtime error
reduce memory footprint
Browse files- app.py +161 -161
- models/region_diffusion_xl.py +11 -6
app.py
CHANGED
@@ -260,45 +260,114 @@ def main():
|
|
260 |
with gr.Row():
|
261 |
gr.Markdown(help_text)
|
262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
with gr.Row():
|
264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
[
|
266 |
-
'{"ops":[{"insert":"
|
267 |
-
'',
|
268 |
-
|
|
|
269 |
0.3,
|
270 |
0.3,
|
|
|
271 |
0.5,
|
272 |
-
3,
|
273 |
-
0,
|
274 |
None,
|
275 |
],
|
276 |
[
|
277 |
-
'{"ops":[{"insert":"
|
278 |
'',
|
279 |
-
|
280 |
-
0.
|
281 |
-
0.3,
|
282 |
0.5,
|
283 |
-
3,
|
284 |
-
0,
|
285 |
-
None,
|
286 |
-
],
|
287 |
-
[
|
288 |
-
'{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
|
289 |
-
'',
|
290 |
-
5,
|
291 |
0.3,
|
292 |
-
|
293 |
-
0.
|
294 |
-
4,
|
295 |
-
0,
|
296 |
None,
|
297 |
],
|
298 |
]
|
299 |
-
|
300 |
-
|
301 |
-
label='Footnote examples',
|
302 |
inputs=[
|
303 |
text_input,
|
304 |
negative_prompt,
|
@@ -319,55 +388,93 @@ def main():
|
|
319 |
fn=generate,
|
320 |
cache_examples=True,
|
321 |
examples_per_page=20)
|
|
|
322 |
# with gr.Row():
|
323 |
-
#
|
324 |
# [
|
325 |
-
# '{"ops":[{"insert":"a
|
326 |
-
# '
|
327 |
-
#
|
328 |
-
# 0.
|
329 |
-
# 0
|
330 |
-
# 0.
|
331 |
-
#
|
332 |
-
# 0
|
333 |
# None,
|
334 |
# ],
|
335 |
# [
|
336 |
-
# '{"ops":[{"insert":"a
|
337 |
-
# '
|
338 |
-
#
|
339 |
-
# 0.
|
340 |
-
# 0
|
341 |
-
# 0
|
342 |
# 6,
|
343 |
# 0.5,
|
344 |
# None,
|
345 |
# ],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
# [
|
347 |
-
# '{"ops":[{"insert":"
|
348 |
-
# '
|
349 |
-
#
|
350 |
-
# 0.5,
|
351 |
# 0.3,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
352 |
# 0.3,
|
353 |
-
#
|
354 |
-
# 0
|
|
|
|
|
355 |
# None,
|
356 |
# ],
|
357 |
# [
|
358 |
-
# '{"ops":[{"insert":"
|
359 |
# '',
|
360 |
-
#
|
361 |
-
# 0.5,
|
362 |
-
# 0.5,
|
363 |
# 0.3,
|
364 |
-
#
|
365 |
-
# 0
|
|
|
|
|
366 |
# None,
|
367 |
# ],
|
368 |
# ]
|
369 |
-
# gr.Examples(examples=
|
370 |
-
# label='Font
|
371 |
# inputs=[
|
372 |
# text_input,
|
373 |
# negative_prompt,
|
@@ -388,113 +495,6 @@ def main():
|
|
388 |
# fn=generate,
|
389 |
# cache_examples=True,
|
390 |
# examples_per_page=20)
|
391 |
-
|
392 |
-
with gr.Row():
|
393 |
-
style_examples = [
|
394 |
-
[
|
395 |
-
'{"ops":[{"insert":"a beautiful"},{"attributes":{"font":"mirza"},"insert":" garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain"},{"insert":" in the background"}]}',
|
396 |
-
'',
|
397 |
-
10,
|
398 |
-
0.6,
|
399 |
-
0,
|
400 |
-
0.4,
|
401 |
-
5,
|
402 |
-
0,
|
403 |
-
None,
|
404 |
-
],
|
405 |
-
[
|
406 |
-
'{"ops":[{"insert":"a night"},{"attributes":{"font":"slabo"},"insert":" sky"},{"insert":" filled with stars above a turbulent"},{"attributes":{"font":"roboto"},"insert":" sea"},{"insert":" with giant waves"}]}',
|
407 |
-
'',
|
408 |
-
2,
|
409 |
-
0.6,
|
410 |
-
0,
|
411 |
-
0,
|
412 |
-
6,
|
413 |
-
0.5,
|
414 |
-
None,
|
415 |
-
],
|
416 |
-
]
|
417 |
-
gr.Examples(examples=style_examples,
|
418 |
-
label='Font style examples',
|
419 |
-
inputs=[
|
420 |
-
text_input,
|
421 |
-
negative_prompt,
|
422 |
-
num_segments,
|
423 |
-
segment_threshold,
|
424 |
-
inject_interval,
|
425 |
-
inject_background,
|
426 |
-
seed,
|
427 |
-
color_guidance_weight,
|
428 |
-
rich_text_input,
|
429 |
-
],
|
430 |
-
outputs=[
|
431 |
-
plaintext_result,
|
432 |
-
richtext_result,
|
433 |
-
segments,
|
434 |
-
token_map,
|
435 |
-
],
|
436 |
-
fn=generate,
|
437 |
-
cache_examples=True,
|
438 |
-
examples_per_page=20)
|
439 |
-
|
440 |
-
with gr.Row():
|
441 |
-
size_examples = [
|
442 |
-
[
|
443 |
-
'{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": " pepperoni, and mushroom on the top"}]}',
|
444 |
-
'',
|
445 |
-
5,
|
446 |
-
0.3,
|
447 |
-
0,
|
448 |
-
0,
|
449 |
-
3,
|
450 |
-
1,
|
451 |
-
None,
|
452 |
-
],
|
453 |
-
[
|
454 |
-
'{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "60px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top"}]}',
|
455 |
-
'',
|
456 |
-
5,
|
457 |
-
0.3,
|
458 |
-
0,
|
459 |
-
0,
|
460 |
-
3,
|
461 |
-
1,
|
462 |
-
None,
|
463 |
-
],
|
464 |
-
[
|
465 |
-
'{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "60px"}, "insert": "mushroom"}, {"insert": " on the top"}]}',
|
466 |
-
'',
|
467 |
-
5,
|
468 |
-
0.3,
|
469 |
-
0,
|
470 |
-
0,
|
471 |
-
3,
|
472 |
-
1,
|
473 |
-
None,
|
474 |
-
],
|
475 |
-
]
|
476 |
-
gr.Examples(examples=size_examples,
|
477 |
-
label='Font size examples',
|
478 |
-
inputs=[
|
479 |
-
text_input,
|
480 |
-
negative_prompt,
|
481 |
-
num_segments,
|
482 |
-
segment_threshold,
|
483 |
-
inject_interval,
|
484 |
-
inject_background,
|
485 |
-
seed,
|
486 |
-
color_guidance_weight,
|
487 |
-
rich_text_input,
|
488 |
-
],
|
489 |
-
outputs=[
|
490 |
-
plaintext_result,
|
491 |
-
richtext_result,
|
492 |
-
segments,
|
493 |
-
token_map,
|
494 |
-
],
|
495 |
-
fn=generate,
|
496 |
-
cache_examples=True,
|
497 |
-
examples_per_page=20)
|
498 |
generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
|
499 |
fn=generate,
|
500 |
inputs=[
|
|
|
260 |
with gr.Row():
|
261 |
gr.Markdown(help_text)
|
262 |
|
263 |
+
# with gr.Row():
|
264 |
+
# footnote_examples = [
|
265 |
+
# [
|
266 |
+
# '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background."}]}',
|
267 |
+
# '',
|
268 |
+
# 9,
|
269 |
+
# 0.3,
|
270 |
+
# 0.3,
|
271 |
+
# 0.5,
|
272 |
+
# 3,
|
273 |
+
# 0,
|
274 |
+
# None,
|
275 |
+
# ],
|
276 |
+
# [
|
277 |
+
# '{"ops":[{"insert":"A cozy "},{"attributes":{"link":"A charming wooden cabin with Christmas decoration, warm light coming out from the windows."},"insert":"cabin"},{"insert":" nestled in a "},{"attributes":{"link":"Towering evergreen trees covered in a thick layer of pristine snow."},"insert":"snowy forest"},{"insert":", and a "},{"attributes":{"link":"A cute snowman wearing a carrot nose, coal eyes, and a colorful scarf, welcoming visitors with a cheerful vibe."},"insert":"snowman"},{"insert":" stands in the yard."}]}',
|
278 |
+
# '',
|
279 |
+
# 12,
|
280 |
+
# 0.4,
|
281 |
+
# 0.3,
|
282 |
+
# 0.5,
|
283 |
+
# 3,
|
284 |
+
# 0,
|
285 |
+
# None,
|
286 |
+
# ],
|
287 |
+
# [
|
288 |
+
# '{"ops":[{"insert":"A "},{"attributes":{"link":"Happy Kung fu panda art, elder, asian art, volumetric lighting, dramatic scene, ultra detailed, realism, chinese"},"insert":"panda"},{"insert":" standing on a cliff by a waterfall, wildlife photography, photograph, high quality, wildlife, f 1.8, soft focus, 8k, national geographic, award - winning photograph by nick nichols"}]}',
|
289 |
+
# '',
|
290 |
+
# 5,
|
291 |
+
# 0.3,
|
292 |
+
# 0,
|
293 |
+
# 0.1,
|
294 |
+
# 4,
|
295 |
+
# 0,
|
296 |
+
# None,
|
297 |
+
# ],
|
298 |
+
# ]
|
299 |
+
|
300 |
+
# gr.Examples(examples=footnote_examples,
|
301 |
+
# label='Footnote examples',
|
302 |
+
# inputs=[
|
303 |
+
# text_input,
|
304 |
+
# negative_prompt,
|
305 |
+
# num_segments,
|
306 |
+
# segment_threshold,
|
307 |
+
# inject_interval,
|
308 |
+
# inject_background,
|
309 |
+
# seed,
|
310 |
+
# color_guidance_weight,
|
311 |
+
# rich_text_input,
|
312 |
+
# ],
|
313 |
+
# outputs=[
|
314 |
+
# plaintext_result,
|
315 |
+
# richtext_result,
|
316 |
+
# segments,
|
317 |
+
# token_map,
|
318 |
+
# ],
|
319 |
+
# fn=generate,
|
320 |
+
# cache_examples=True,
|
321 |
+
# examples_per_page=20)
|
322 |
with gr.Row():
|
323 |
+
color_examples = [
|
324 |
+
# [
|
325 |
+
# '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#04a704"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
|
326 |
+
# 'lowres, had anatomy, bad hands, cropped, worst quality',
|
327 |
+
# 11,
|
328 |
+
# 0.5,
|
329 |
+
# 0.3,
|
330 |
+
# 0.3,
|
331 |
+
# 6,
|
332 |
+
# 0.5,
|
333 |
+
# None,
|
334 |
+
# ],
|
335 |
+
# [
|
336 |
+
# '{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#ff5df1"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
|
337 |
+
# 'lowres, had anatomy, bad hands, cropped, worst quality',
|
338 |
+
# 11,
|
339 |
+
# 0.5,
|
340 |
+
# 0.3,
|
341 |
+
# 0.3,
|
342 |
+
# 6,
|
343 |
+
# 0.5,
|
344 |
+
# None,
|
345 |
+
# ],
|
346 |
[
|
347 |
+
'{"ops":[{"insert":"a beautifule girl with big eye, skin, and long "},{"attributes":{"color":"#999999"},"insert":"hair"},{"insert":", t-shirt, bursting with vivid color, intricate, elegant, highly detailed, photorealistic, digital painting, artstation, illustration, concept art."}]}',
|
348 |
+
'lowres, had anatomy, bad hands, cropped, worst quality',
|
349 |
+
11,
|
350 |
+
0.5,
|
351 |
0.3,
|
352 |
0.3,
|
353 |
+
6,
|
354 |
0.5,
|
|
|
|
|
355 |
None,
|
356 |
],
|
357 |
[
|
358 |
+
'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#FD6C9E"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background."}]}',
|
359 |
'',
|
360 |
+
10,
|
361 |
+
0.5,
|
|
|
362 |
0.5,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
0.3,
|
364 |
+
7,
|
365 |
+
0.5,
|
|
|
|
|
366 |
None,
|
367 |
],
|
368 |
]
|
369 |
+
gr.Examples(examples=color_examples,
|
370 |
+
label='Font color examples',
|
|
|
371 |
inputs=[
|
372 |
text_input,
|
373 |
negative_prompt,
|
|
|
388 |
fn=generate,
|
389 |
cache_examples=True,
|
390 |
examples_per_page=20)
|
391 |
+
|
392 |
# with gr.Row():
|
393 |
+
# style_examples = [
|
394 |
# [
|
395 |
+
# '{"ops":[{"insert":"a beautiful"},{"attributes":{"font":"mirza"},"insert":" garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain"},{"insert":" in the background"}]}',
|
396 |
+
# '',
|
397 |
+
# 10,
|
398 |
+
# 0.6,
|
399 |
+
# 0,
|
400 |
+
# 0.4,
|
401 |
+
# 5,
|
402 |
+
# 0,
|
403 |
# None,
|
404 |
# ],
|
405 |
# [
|
406 |
+
# '{"ops":[{"insert":"a night"},{"attributes":{"font":"slabo"},"insert":" sky"},{"insert":" filled with stars above a turbulent"},{"attributes":{"font":"roboto"},"insert":" sea"},{"insert":" with giant waves"}]}',
|
407 |
+
# '',
|
408 |
+
# 2,
|
409 |
+
# 0.6,
|
410 |
+
# 0,
|
411 |
+
# 0,
|
412 |
# 6,
|
413 |
# 0.5,
|
414 |
# None,
|
415 |
# ],
|
416 |
+
# ]
|
417 |
+
# gr.Examples(examples=style_examples,
|
418 |
+
# label='Font style examples',
|
419 |
+
# inputs=[
|
420 |
+
# text_input,
|
421 |
+
# negative_prompt,
|
422 |
+
# num_segments,
|
423 |
+
# segment_threshold,
|
424 |
+
# inject_interval,
|
425 |
+
# inject_background,
|
426 |
+
# seed,
|
427 |
+
# color_guidance_weight,
|
428 |
+
# rich_text_input,
|
429 |
+
# ],
|
430 |
+
# outputs=[
|
431 |
+
# plaintext_result,
|
432 |
+
# richtext_result,
|
433 |
+
# segments,
|
434 |
+
# token_map,
|
435 |
+
# ],
|
436 |
+
# fn=generate,
|
437 |
+
# cache_examples=True,
|
438 |
+
# examples_per_page=20)
|
439 |
+
|
440 |
+
# with gr.Row():
|
441 |
+
# size_examples = [
|
442 |
# [
|
443 |
+
# '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "60px"}, "insert": "pineapple"}, {"insert": " pepperoni, and mushroom on the top"}]}',
|
444 |
+
# '',
|
445 |
+
# 5,
|
|
|
446 |
# 0.3,
|
447 |
+
# 0,
|
448 |
+
# 0,
|
449 |
+
# 3,
|
450 |
+
# 1,
|
451 |
+
# None,
|
452 |
+
# ],
|
453 |
+
# [
|
454 |
+
# '{"ops": [{"insert": "A pizza with pineapple, "}, {"attributes": {"size": "60px"}, "insert": "pepperoni"}, {"insert": ", and mushroom on the top"}]}',
|
455 |
+
# '',
|
456 |
+
# 5,
|
457 |
# 0.3,
|
458 |
+
# 0,
|
459 |
+
# 0,
|
460 |
+
# 3,
|
461 |
+
# 1,
|
462 |
# None,
|
463 |
# ],
|
464 |
# [
|
465 |
+
# '{"ops": [{"insert": "A pizza with pineapple, pepperoni, and "}, {"attributes": {"size": "60px"}, "insert": "mushroom"}, {"insert": " on the top"}]}',
|
466 |
# '',
|
467 |
+
# 5,
|
|
|
|
|
468 |
# 0.3,
|
469 |
+
# 0,
|
470 |
+
# 0,
|
471 |
+
# 3,
|
472 |
+
# 1,
|
473 |
# None,
|
474 |
# ],
|
475 |
# ]
|
476 |
+
# gr.Examples(examples=size_examples,
|
477 |
+
# label='Font size examples',
|
478 |
# inputs=[
|
479 |
# text_input,
|
480 |
# negative_prompt,
|
|
|
495 |
# fn=generate,
|
496 |
# cache_examples=True,
|
497 |
# examples_per_page=20)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
498 |
generate_button.click(fn=lambda: gr.update(visible=False), inputs=None, outputs=share_row, queue=False).then(
|
499 |
fn=generate,
|
500 |
inputs=[
|
models/region_diffusion_xl.py
CHANGED
@@ -846,12 +846,16 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
|
|
846 |
# apply guidance
|
847 |
if use_guidance and t < text_format_dict['guidance_start_step']:
|
848 |
with torch.enable_grad():
|
|
|
|
|
849 |
if not latents.requires_grad:
|
850 |
latents.requires_grad = True
|
851 |
# import ipdb;ipdb.set_trace()
|
852 |
-
latents_0 = self.predict_x0(latents, noise_pred, t).to(dtype=latents.dtype)
|
|
|
853 |
latents_inp = latents_0 / self.vae.config.scaling_factor
|
854 |
-
imgs = self.vae.
|
|
|
855 |
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
856 |
loss_total = 0.
|
857 |
for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
|
@@ -863,6 +867,7 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
|
|
863 |
loss_total.backward()
|
864 |
latents = (
|
865 |
latents - latents.grad * text_format_dict['color_guidance_weight'] * text_format_dict['color_obj_atten_all']).detach().clone().to(dtype=prompt_embeds.dtype)
|
|
|
866 |
|
867 |
# apply background injection
|
868 |
if i == int(inject_background * len(self.scheduler.timesteps)) and inject_background > 0:
|
@@ -1023,7 +1028,7 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
|
|
1023 |
PyTorch Forward hook to save outputs at each forward pass.
|
1024 |
"""
|
1025 |
if 'attn1' in name:
|
1026 |
-
modified_args = (args[0], self.self_attention_maps_cur[name])
|
1027 |
return modified_args
|
1028 |
# cross attention injection
|
1029 |
# elif 'attn2' in name:
|
@@ -1039,7 +1044,7 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
|
|
1039 |
PyTorch Forward hook to save outputs at each forward pass.
|
1040 |
"""
|
1041 |
modified_args = (args[0], args[1],
|
1042 |
-
self.self_attention_maps_cur[name])
|
1043 |
return modified_args
|
1044 |
for name, module in self.unet.named_modules():
|
1045 |
leaf_name = name.split('.')[-1]
|
@@ -1077,7 +1082,7 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
|
|
1077 |
# activations[name] = out[1][1].detach()
|
1078 |
else:
|
1079 |
assert out[1][1].shape[-1] != 77
|
1080 |
-
activations[name] = out[1][1].detach()
|
1081 |
|
1082 |
def save_resnet_activations(activations, name, module, inp, out):
|
1083 |
r"""
|
@@ -1087,7 +1092,7 @@ class RegionDiffusionXL(DiffusionPipeline, FromSingleFileMixin):
|
|
1087 |
# out[1] - residual hidden feature
|
1088 |
# import ipdb;ipdb.set_trace()
|
1089 |
assert out[1].shape[-1] == 64
|
1090 |
-
activations[name] = out[1].detach()
|
1091 |
attention_dict = collections.defaultdict(list)
|
1092 |
for name, module in self.unet.named_modules():
|
1093 |
leaf_name = name.split('.')[-1]
|
|
|
846 |
# apply guidance
|
847 |
if use_guidance and t < text_format_dict['guidance_start_step']:
|
848 |
with torch.enable_grad():
|
849 |
+
self.unet.to(device='cpu')
|
850 |
+
torch.cuda.empty_cache()
|
851 |
if not latents.requires_grad:
|
852 |
latents.requires_grad = True
|
853 |
# import ipdb;ipdb.set_trace()
|
854 |
+
# latents_0 = self.predict_x0(latents, noise_pred, t).to(dtype=latents.dtype)
|
855 |
+
latents_0 = self.predict_x0(latents, noise_pred, t).to(dtype=torch.bfloat16)
|
856 |
latents_inp = latents_0 / self.vae.config.scaling_factor
|
857 |
+
imgs = self.vae.to(dtype=latents_inp.dtype).decode(latents_inp).sample
|
858 |
+
# imgs = self.vae.decode(latents_inp.to(dtype=torch.float32)).sample
|
859 |
imgs = (imgs / 2 + 0.5).clamp(0, 1)
|
860 |
loss_total = 0.
|
861 |
for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
|
|
|
867 |
loss_total.backward()
|
868 |
latents = (
|
869 |
latents - latents.grad * text_format_dict['color_guidance_weight'] * text_format_dict['color_obj_atten_all']).detach().clone().to(dtype=prompt_embeds.dtype)
|
870 |
+
self.unet.to(device=latents.device)
|
871 |
|
872 |
# apply background injection
|
873 |
if i == int(inject_background * len(self.scheduler.timesteps)) and inject_background > 0:
|
|
|
1028 |
PyTorch Forward hook to save outputs at each forward pass.
|
1029 |
"""
|
1030 |
if 'attn1' in name:
|
1031 |
+
modified_args = (args[0], self.self_attention_maps_cur[name].to(args[0].device))
|
1032 |
return modified_args
|
1033 |
# cross attention injection
|
1034 |
# elif 'attn2' in name:
|
|
|
1044 |
PyTorch Forward hook to save outputs at each forward pass.
|
1045 |
"""
|
1046 |
modified_args = (args[0], args[1],
|
1047 |
+
self.self_attention_maps_cur[name].to(args[0].device))
|
1048 |
return modified_args
|
1049 |
for name, module in self.unet.named_modules():
|
1050 |
leaf_name = name.split('.')[-1]
|
|
|
1082 |
# activations[name] = out[1][1].detach()
|
1083 |
else:
|
1084 |
assert out[1][1].shape[-1] != 77
|
1085 |
+
activations[name] = out[1][1].detach().cpu()
|
1086 |
|
1087 |
def save_resnet_activations(activations, name, module, inp, out):
|
1088 |
r"""
|
|
|
1092 |
# out[1] - residual hidden feature
|
1093 |
# import ipdb;ipdb.set_trace()
|
1094 |
assert out[1].shape[-1] == 64
|
1095 |
+
activations[name] = out[1].detach().cpu()
|
1096 |
attention_dict = collections.defaultdict(list)
|
1097 |
for name, module in self.unet.named_modules():
|
1098 |
leaf_name = name.split('.')[-1]
|