Spaces:
Runtime error
Runtime error
Linoy Tsaban
commited on
Commit
•
8f7289c
1
Parent(s):
998e5bc
Update app.py
Browse files
app.py
CHANGED
@@ -132,7 +132,23 @@ def edit(input_image,
|
|
132 |
neg_guidance_1, neg_guidance_2, neg_guidance_3,
|
133 |
threshold_1, threshold_2, threshold_3,
|
134 |
do_reconstruction,
|
135 |
-
reconstruction
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
if edit_concept_1 != "" or edit_concept_2 != "" or edit_concept_3 != "":
|
138 |
editing_args = dict(
|
@@ -151,7 +167,7 @@ def edit(input_image,
|
|
151 |
num_inference_steps=steps,
|
152 |
use_ddpm=True, wts=wts.value, zs=zs.value[skip:], **editing_args)
|
153 |
|
154 |
-
return sega_out.images[0], reconstruct_button.update(visible=True), do_reconstruction, reconstruction
|
155 |
|
156 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
157 |
|
@@ -159,9 +175,9 @@ def edit(input_image,
|
|
159 |
pure_ddpm_img = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
160 |
reconstruction = gr.State(value=pure_ddpm_img)
|
161 |
do_reconstruction = False
|
162 |
-
return pure_ddpm_img, reconstruct_button.update(visible=False), do_reconstruction, reconstruction
|
163 |
|
164 |
-
return reconstruction.value, reconstruct_button.update(visible=False), do_reconstruction, reconstruction
|
165 |
|
166 |
|
167 |
def randomize_seed_fn(seed, randomize_seed):
|
@@ -635,21 +651,7 @@ with gr.Blocks(css="style.css") as demo:
|
|
635 |
#add_concept_button.click(fn = update_display_concept, inputs=sega_concepts_counter,
|
636 |
# outputs= [row2, row2_advanced, row3, row3_advanced, add_concept_button, sega_concepts_counter], queue = False)
|
637 |
|
638 |
-
run_button.click(
|
639 |
-
fn=load_and_invert,
|
640 |
-
inputs=[input_image,
|
641 |
-
do_inversion,
|
642 |
-
seed, randomize_seed,
|
643 |
-
wts, zs,
|
644 |
-
src_prompt,
|
645 |
-
tar_prompt,
|
646 |
-
steps,
|
647 |
-
src_cfg_scale,
|
648 |
-
skip,
|
649 |
-
tar_cfg_scale
|
650 |
-
],
|
651 |
-
outputs=[wts, zs, do_inversion, inversion_progress],
|
652 |
-
).success(
|
653 |
fn=edit,
|
654 |
inputs=[input_image,
|
655 |
wts, zs,
|
@@ -661,10 +663,16 @@ with gr.Blocks(css="style.css") as demo:
|
|
661 |
guidnace_scale_1,guidnace_scale_2,guidnace_scale_3,
|
662 |
warmup_1, warmup_2, warmup_3,
|
663 |
neg_guidance_1, neg_guidance_2, neg_guidance_3,
|
664 |
-
threshold_1, threshold_2, threshold_3, do_reconstruction, reconstruction
|
|
|
|
|
|
|
|
|
|
|
|
|
665 |
|
666 |
],
|
667 |
-
outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction])
|
668 |
# .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
|
669 |
|
670 |
|
|
|
132 |
neg_guidance_1, neg_guidance_2, neg_guidance_3,
|
133 |
threshold_1, threshold_2, threshold_3,
|
134 |
do_reconstruction,
|
135 |
+
reconstruction,
|
136 |
+
|
137 |
+
# for inversion in case it needs to be re computed (and avoid delay):
|
138 |
+
do_inversion,
|
139 |
+
seed,
|
140 |
+
randomize_seed,
|
141 |
+
src_prompt,
|
142 |
+
src_cfg_scale):
|
143 |
+
|
144 |
+
if do_inversion or randomize_seed:
|
145 |
+
x0 = load_512(input_image, device=device)
|
146 |
+
# invert and retrieve noise maps and latent
|
147 |
+
zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
|
148 |
+
wts = gr.State(value=wts_tensor)
|
149 |
+
zs = gr.State(value=zs_tensor)
|
150 |
+
do_inversion = False
|
151 |
+
|
152 |
|
153 |
if edit_concept_1 != "" or edit_concept_2 != "" or edit_concept_3 != "":
|
154 |
editing_args = dict(
|
|
|
167 |
num_inference_steps=steps,
|
168 |
use_ddpm=True, wts=wts.value, zs=zs.value[skip:], **editing_args)
|
169 |
|
170 |
+
return sega_out.images[0], reconstruct_button.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion
|
171 |
|
172 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
173 |
|
|
|
175 |
pure_ddpm_img = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
176 |
reconstruction = gr.State(value=pure_ddpm_img)
|
177 |
do_reconstruction = False
|
178 |
+
return pure_ddpm_img, reconstruct_button.update(visible=False), do_reconstruction, reconstruction wts, zs, do_inversion
|
179 |
|
180 |
+
return reconstruction.value, reconstruct_button.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion
|
181 |
|
182 |
|
183 |
def randomize_seed_fn(seed, randomize_seed):
|
|
|
651 |
#add_concept_button.click(fn = update_display_concept, inputs=sega_concepts_counter,
|
652 |
# outputs= [row2, row2_advanced, row3, row3_advanced, add_concept_button, sega_concepts_counter], queue = False)
|
653 |
|
654 |
+
run_button.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
655 |
fn=edit,
|
656 |
inputs=[input_image,
|
657 |
wts, zs,
|
|
|
663 |
guidnace_scale_1,guidnace_scale_2,guidnace_scale_3,
|
664 |
warmup_1, warmup_2, warmup_3,
|
665 |
neg_guidance_1, neg_guidance_2, neg_guidance_3,
|
666 |
+
threshold_1, threshold_2, threshold_3, do_reconstruction, reconstruction,
|
667 |
+
do_inversion,
|
668 |
+
seed,
|
669 |
+
randomize_seed,
|
670 |
+
src_prompt,
|
671 |
+
src_cfg_scale
|
672 |
+
|
673 |
|
674 |
],
|
675 |
+
outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs, do_inversion])
|
676 |
# .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
|
677 |
|
678 |
|