Linoy Tsaban commited on
Commit
a44c2bb
1 Parent(s): 7e3c69d

Update app.py

Browse files

clear fixed + reconstruct

Files changed (1) hide show
  1. app.py +203 -140
app.py CHANGED
@@ -71,39 +71,44 @@ def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
71
  return img
72
 
73
 
74
- def reconstruct(tar_prompt,
75
- tar_cfg_scale,
76
- skip,
77
- wts, zs,
78
  do_reconstruction,
79
- reconstruction, reconstruct_button, hide_reconstruct_button
 
80
  ):
81
 
82
- if do_reconstruction:
83
- reconstruction_img = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
84
- reconstruction = gr.State(value=reconstruction_img)
85
- do_reconstruction = False
86
- return reconstruction.value, reconstruction, do_reconstruction, reconstruct_button.update(visible=False), hide_reconstruct_button.update(visible=True)
 
 
 
 
 
87
 
88
-
89
  def load_and_invert(
90
- input_image,
91
  do_inversion,
92
  seed, randomize_seed,
93
- wts, zs,
94
- src_prompt ="",
95
- tar_prompt="",
96
  steps=100,
97
  src_cfg_scale = 3.5,
98
  skip=36,
99
- tar_cfg_scale=15,
100
  progress=gr.Progress(track_tqdm=True)
101
-
102
  ):
103
 
104
-
105
  x0 = load_512(input_image, device=device)
106
-
107
  if do_inversion or randomize_seed:
108
  # invert and retrieve noise maps and latent
109
  zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
@@ -111,43 +116,40 @@ def load_and_invert(
111
  zs = gr.State(value=zs_tensor)
112
  do_inversion = False
113
 
114
- inversion_progress = "Inversion compeleted!"
115
- return wts, zs, do_inversion, inversion_progress
116
 
117
  ## SEGA ##
118
-
119
  def edit(input_image,
120
- wts, zs,
121
- tar_prompt,
122
  steps,
123
  skip,
124
  tar_cfg_scale,
125
- reconstruct_button,
126
  edit_concept_1,edit_concept_2,edit_concept_3,
127
  guidnace_scale_1,guidnace_scale_2,guidnace_scale_3,
128
  warmup_1, warmup_2, warmup_3,
129
  neg_guidance_1, neg_guidance_2, neg_guidance_3,
130
  threshold_1, threshold_2, threshold_3):
131
-
132
-
133
  editing_args = dict(
134
  editing_prompt = [edit_concept_1,edit_concept_2,edit_concept_3],
135
  reverse_editing_direction = [ neg_guidance_1, neg_guidance_2, neg_guidance_3,],
136
  edit_warmup_steps=[warmup_1, warmup_2, warmup_3,],
137
- edit_guidance_scale=[guidnace_scale_1,guidnace_scale_2,guidnace_scale_3],
138
  edit_threshold=[threshold_1, threshold_2, threshold_3],
139
- edit_momentum_scale=0.3,
140
  edit_mom_beta=0.6,
141
  eta=1,)
142
-
143
  latnets = wts.value[skip].expand(1, -1, -1, -1)
144
  sega_out = sem_pipe(prompt=tar_prompt, latents=latnets, guidance_scale = tar_cfg_scale,
145
- num_images_per_prompt=1,
146
- num_inference_steps=steps,
147
  use_ddpm=True, wts=wts.value, zs=zs.value[skip:], **editing_args)
148
-
149
- # return sega_out.images[0], reconstruct_button.update(visible=True)
150
- return sega_out.images[0]
151
 
152
 
153
 
@@ -215,7 +217,7 @@ def get_example():
215
 
216
  ########
217
  # demo #
218
- ########
219
 
220
  intro = """
221
  <h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
@@ -234,67 +236,74 @@ For faster inference without waiting in queue, you may duplicate the space and u
234
 
235
  help_text = """
236
  - **Getting Started - edit images with DDPM X SEGA:**
237
-
238
- The are 3 general setting options you can play with -
239
-
240
  1. **Pure DDPM Edit -** Describe the desired edited output image in detail
241
- 2. **Pure SEGA Edit -** Keep the target prompt empty ***or*** with a description of the original image and add editing concepts for Semantic Gudiance editing
242
- 3. **Combined -** Describe the desired edited output image in detail and add additional SEGA editing concepts on top
243
  - **Getting Started - Tips**
244
-
245
  While the best approach depends on your editing objective and source image, we can layout a few guiding tips to use as a starting point -
246
-
247
  1. **DDPM** is usually more suited for scene/style changes and major subject changes (for example ) while **SEGA** allows for more fine grained control, changes are more delicate, more suited for adding details (for example facial expressions and attributes, subtle style modifications, object adding/removing)
248
- 2. The more you describe the scene in the target prompt (both the parts and details you wish to keep the same and those you wish to change), the better the result
249
- 3. **Combining DDPM Edit with SEGA -**
250
- Try dividing your editing objective to more significant scene/style/subject changes and detail adding/removing and more moderate changes. Then describe the major changes in a detailed target prompt and add the more fine grained details as SEGA concepts.
251
  4. **Reconstruction:** Using an empty source prompt + target prompt will lead to a perfect reconstruction
252
  - **Fidelity vs creativity**:
253
-
254
  Bigger values → more fidelity, smaller values → more creativity
255
-
256
- 1. `Skip Steps`
257
  2. `Warmup` (SEGA)
258
  3. `Threshold` (SEGA)
259
-
260
  Bigger values → more creativity, smaller values → more fidelity
261
-
262
  1. `Guidance Scale`
263
  2. `Concept Guidance Scale` (SEGA)
264
  """
265
 
266
  with gr.Blocks(css='style.css') as demo:
267
-
268
  def add_concept(sega_concepts_counter):
269
  if sega_concepts_counter == 1:
270
  return row2.update(visible=True), row3.update(visible=False), add_concept_button.update(visible=True), 2
271
  else:
272
  return row2.update(visible=True), row3.update(visible=True), add_concept_button.update(visible=False), 3
273
 
274
-
275
  def reset_do_inversion():
276
  do_inversion = True
277
  return do_inversion
278
 
 
 
 
 
279
 
280
  def update_inversion_progress_visibility(do_inversion):
281
  if do_inversion:
282
  return inversion_progress.update(visible=True)
283
  else:
284
  return inversion_progress.update(visible=False)
285
-
286
-
287
 
288
-
 
 
 
 
289
  gr.HTML(intro)
290
  wts = gr.State()
291
  zs = gr.State()
292
- do_inversion = gr.State(value=True)
293
  reconstruction = gr.State()
 
 
294
  sega_concepts_counter = gr.State(1)
295
 
296
-
297
-
298
  with gr.Row():
299
  input_image = gr.Image(label="Input Image", interactive=True)
300
  ddpm_edited_image = gr.Image(label=f"DDPM Reconstructed Image", interactive=False, visible=False)
@@ -304,7 +313,7 @@ with gr.Blocks(css='style.css') as demo:
304
  sega_edited_image.style(height=365, width=365)
305
 
306
  with gr.Row():
307
- inversion_progress = gr.Textbox(visible=False)
308
 
309
  with gr.Tabs() as tabs:
310
  with gr.TabItem('1. Describe the desired output', id=0):
@@ -321,34 +330,25 @@ with gr.Blocks(css='style.css') as demo:
321
  with gr.Row().style(mobile_collapse=False, equal_height=True):
322
  neg_guidance_1 = gr.Checkbox(
323
  label='Negative Guidance')
324
- warmup_1 = gr.Slider(label='Warmup', minimum=0, maximum=50,
325
- value=DEFAULT_WARMUP_STEPS,
326
- step=1, interactive=True)
327
- guidnace_scale_1 = gr.Slider(label='Concept Guidance Scale', minimum=1, maximum=30,
328
- value=DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE,
329
  step=0.5, interactive=True)
330
- threshold_1 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99,
331
- value=DEFAULT_THRESHOLD, steps=0.01, interactive=True)
332
  edit_concept_1 = gr.Textbox(
333
  label="Edit Concept",
334
  show_label=False,
335
  max_lines=1,
336
  placeholder="Enter your 1st edit prompt",
337
  )
338
-
339
  # 2nd SEGA concept
340
  with gr.Row(visible=False) as row2:
341
  neg_guidance_2 = gr.Checkbox(
342
  label='Negative Guidance',visible=True)
343
- warmup_2 = gr.Slider(label='Warmup', minimum=0, maximum=50,
344
- value=DEFAULT_WARMUP_STEPS,
345
- step=1, interactive=True)
346
- guidnace_scale_2 = gr.Slider(label='Concept Guidance Scale', minimum=1, maximum=30,
347
- value=DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE,
348
  step=0.5, interactive=True)
349
- threshold_2 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99,
350
- value=DEFAULT_THRESHOLD,
351
- steps=0.01, interactive=True)
352
  edit_concept_2 = gr.Textbox(
353
  label="Edit Concept",
354
  show_label=False,
@@ -359,34 +359,30 @@ with gr.Blocks(css='style.css') as demo:
359
  with gr.Row(visible=False) as row3:
360
  neg_guidance_3 = gr.Checkbox(
361
  label='Negative Guidance',visible=True)
362
- warmup_3 = gr.Slider(label='Warmup', minimum=0, maximum=50,
363
- value=DEFAULT_WARMUP_STEPS, step=1,
364
- interactive=True)
365
- guidnace_scale_3 = gr.Slider(label='Concept Guidance Scale', minimum=1, maximum=30,
366
- value=DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE,
367
  step=0.5, interactive=True)
368
- threshold_3 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99,
369
- value=DEFAULT_THRESHOLD, steps=0.01,
370
- interactive=True)
371
  edit_concept_3 = gr.Textbox(
372
  label="Edit Concept",
373
  show_label=False,
374
  max_lines=1,
375
  placeholder="Enter your 3rd edit prompt",
376
  )
377
-
378
  with gr.Row().style(mobile_collapse=False, equal_height=True):
379
  add_concept_button = gr.Button("+")
380
 
381
-
382
  with gr.Row():
383
  run_button = gr.Button("Edit")
384
  reconstruct_button = gr.Button("Show Reconstruction", visible=False)
385
- undo_button = gr.Button("Undo", visible=False)
386
-
387
- clear_button = gr.ClearButton()
388
 
389
  with gr.Accordion("Advanced Options", open=False):
 
 
390
  with gr.Row():
391
  with gr.Column():
392
  src_prompt = gr.Textbox(lines=1, label="Source Prompt", interactive=True, placeholder="")
@@ -394,53 +390,73 @@ with gr.Blocks(css='style.css') as demo:
394
  src_cfg_scale = gr.Number(value=3.5, label=f"Source Guidance Scale", interactive=True)
395
  seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
396
  randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
397
- with gr.Column():
398
  skip = gr.Slider(minimum=0, maximum=60, value=36, label="Skip Steps", interactive=True)
399
- tar_cfg_scale = gr.Slider(minimum=7, maximum=30,value=15, label=f"Guidance Scale", interactive=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
 
401
 
402
-
403
 
404
 
405
  # with gr.Accordion("Help", open=False):
406
  # gr.Markdown(help_text)
407
-
408
  caption_button.click(
409
  fn = caption_image,
410
  inputs = [input_image],
411
  outputs = [tar_prompt]
412
  )
413
-
414
  add_concept_button.click(fn = add_concept, inputs=sega_concepts_counter,
415
  outputs= [row2, row3, add_concept_button, sega_concepts_counter], queue = False)
416
-
417
- run_button.click(
418
- fn = randomize_seed_fn,
419
- inputs = [seed, randomize_seed],
420
- outputs = [seed],
421
- queue = False).then(fn = update_inversion_progress_visibility, inputs =[do_inversion], outputs=[inversion_progress],queue=False).then(
422
  fn=load_and_invert,
423
- inputs=[input_image,
424
  do_inversion,
425
  seed, randomize_seed,
426
- wts, zs,
427
- src_prompt,
428
- tar_prompt,
429
  steps,
430
  src_cfg_scale,
431
  skip,
432
- tar_cfg_scale
433
  ],
434
  outputs=[wts, zs, do_inversion, inversion_progress],
435
  ).then(fn = update_inversion_progress_visibility, inputs =[do_inversion], outputs=[inversion_progress],queue=False).success(
436
  fn=edit,
437
- inputs=[input_image,
438
- wts, zs,
439
- tar_prompt,
440
  steps,
441
  skip,
442
  tar_cfg_scale,
443
- reconstruct_button,
444
  edit_concept_1,edit_concept_2,edit_concept_3,
445
  guidnace_scale_1,guidnace_scale_2,guidnace_scale_3,
446
  warmup_1, warmup_2, warmup_3,
@@ -448,66 +464,112 @@ with gr.Blocks(css='style.css') as demo:
448
  threshold_1, threshold_2, threshold_3
449
 
450
  ],
451
- # outputs=[sega_edited_image, reconstruct_button]
452
- outputs=[sega_edited_image] )
453
 
454
 
455
 
456
  # Automatically start inverting upon input_image change
457
  input_image.change(
458
  fn = reset_do_inversion,
459
- outputs = [do_inversion],
460
- queue = False).then(fn = update_inversion_progress_visibility, inputs =[do_inversion], outputs=[inversion_progress],queue=False).then(
 
461
  fn=load_and_invert,
462
- inputs=[input_image,
463
  do_inversion,
464
  seed, randomize_seed,
465
- wts, zs,
466
- src_prompt,
467
- tar_prompt,
468
  steps,
469
  src_cfg_scale,
470
  skip,
471
- tar_cfg_scale,
472
  ],
473
  # outputs=[ddpm_edited_image, wts, zs, do_inversion],
474
  outputs=[wts, zs, do_inversion, inversion_progress],
475
- ).then(fn = update_inversion_progress_visibility, inputs =[do_inversion], outputs=[inversion_progress],queue=False)
 
 
 
 
 
 
476
 
477
-
478
- # Repeat inversion when these params are changed:
479
  src_prompt.change(
480
  fn = reset_do_inversion,
481
- outputs = [do_inversion], queue = False)
482
-
483
- steps.change(fn = reset_do_inversion,
484
- outputs = [do_inversion], queue = False)
485
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
 
487
- src_cfg_scale.change(fn = reset_do_inversion,
488
- outputs = [do_inversion], queue = False)
 
489
 
490
- components_to_clear = [input_image,ddpm_edited_image,sega_edited_image, do_inversion,
 
 
491
  src_prompt, steps, src_cfg_scale, seed,
492
- tar_prompt, skip, tar_cfg_scale,
493
  edit_concept_1, guidnace_scale_1,warmup_1, threshold_1, neg_guidance_1,
494
  edit_concept_2, guidnace_scale_2,warmup_2, threshold_2, neg_guidance_2,
495
- edit_concept_3, guidnace_scale_3,warmup_3, threshold_3, neg_guidance_3,
496
-
497
- ]
498
- clear_output_vals = [None, None, None, True,
499
  "", DEFAULT_DIFFUSION_STEPS, DEFAULT_SOURCE_GUIDANCE_SCALE, DEFAULT_SEED,
500
- "", DEFAULT_SKIP_STEPS, DEFAULT_TARGET_GUIDANCE_SCALE,
501
  "", DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE, DEFAULT_WARMUP_STEPS, DEFAULT_THRESHOLD, DEFAULT_NEGATIVE_GUIDANCE,
502
  "", DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE, DEFAULT_WARMUP_STEPS, DEFAULT_THRESHOLD, DEFAULT_NEGATIVE_GUIDANCE,
503
  "", DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE, DEFAULT_WARMUP_STEPS, DEFAULT_THRESHOLD, DEFAULT_NEGATIVE_GUIDANCE,
504
- ]
505
- clear_button.click(lambda:clear_output_vals, outputs =components_to_clear)
506
-
 
 
 
 
 
 
 
 
 
 
 
507
 
 
 
 
 
 
 
508
  # gr.Examples(
509
- # label='Examples',
510
- # examples=get_example(),
511
  # inputs=[input_image, src_prompt, tar_prompt, steps,
512
  # # src_cfg_scale,
513
  # skip,
@@ -527,6 +589,7 @@ with gr.Blocks(css='style.css') as demo:
527
 
528
 
529
 
 
530
  demo.queue()
531
  demo.launch(share=False)
532
 
 
71
  return img
72
 
73
 
74
+ def reconstruct(tar_prompt,
75
+ tar_cfg_scale,
76
+ skip,
77
+ wts, zs,
78
  do_reconstruction,
79
+ reconstruction,
80
+ reconstruct_button
81
  ):
82
 
83
+ if reconstruct_button == "Hide Reconstruction":
84
+ return reconstruction.value, reconstruction, ddpm_edited_image.update(visible=False), do_reconstruction, "Show Reconstruction"
85
+
86
+ else:
87
+ if do_reconstruction:
88
+ reconstruction_img = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
89
+ reconstruction = gr.State(value=reconstruction_img)
90
+ do_reconstruction = False
91
+ return reconstruction.value, reconstruction, ddpm_edited_image.update(visible=True), do_reconstruction, "Hide Reconstruction"
92
+
93
 
 
94
  def load_and_invert(
95
+ input_image,
96
  do_inversion,
97
  seed, randomize_seed,
98
+ wts, zs,
99
+ src_prompt ="",
100
+ tar_prompt="",
101
  steps=100,
102
  src_cfg_scale = 3.5,
103
  skip=36,
104
+ tar_cfg_scale=15,
105
  progress=gr.Progress(track_tqdm=True)
106
+
107
  ):
108
 
109
+
110
  x0 = load_512(input_image, device=device)
111
+
112
  if do_inversion or randomize_seed:
113
  # invert and retrieve noise maps and latent
114
  zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
 
116
  zs = gr.State(value=zs_tensor)
117
  do_inversion = False
118
 
119
+ return wts, zs, do_inversion, inversion_progress.update(visible=False)
 
120
 
121
  ## SEGA ##
122
+
123
  def edit(input_image,
124
+ wts, zs,
125
+ tar_prompt,
126
  steps,
127
  skip,
128
  tar_cfg_scale,
 
129
  edit_concept_1,edit_concept_2,edit_concept_3,
130
  guidnace_scale_1,guidnace_scale_2,guidnace_scale_3,
131
  warmup_1, warmup_2, warmup_3,
132
  neg_guidance_1, neg_guidance_2, neg_guidance_3,
133
  threshold_1, threshold_2, threshold_3):
134
+
135
+
136
  editing_args = dict(
137
  editing_prompt = [edit_concept_1,edit_concept_2,edit_concept_3],
138
  reverse_editing_direction = [ neg_guidance_1, neg_guidance_2, neg_guidance_3,],
139
  edit_warmup_steps=[warmup_1, warmup_2, warmup_3,],
140
+ edit_guidance_scale=[guidnace_scale_1,guidnace_scale_2,guidnace_scale_3],
141
  edit_threshold=[threshold_1, threshold_2, threshold_3],
142
+ edit_momentum_scale=0.3,
143
  edit_mom_beta=0.6,
144
  eta=1,)
145
+
146
  latnets = wts.value[skip].expand(1, -1, -1, -1)
147
  sega_out = sem_pipe(prompt=tar_prompt, latents=latnets, guidance_scale = tar_cfg_scale,
148
+ num_images_per_prompt=1,
149
+ num_inference_steps=steps,
150
  use_ddpm=True, wts=wts.value, zs=zs.value[skip:], **editing_args)
151
+
152
+ return sega_out.images[0], reconstruct_button.update(visible=True)
 
153
 
154
 
155
 
 
217
 
218
  ########
219
  # demo #
220
+ ########
221
 
222
  intro = """
223
  <h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
 
236
 
237
  help_text = """
238
  - **Getting Started - edit images with DDPM X SEGA:**
239
+
240
+ The are 3 general setting options you can play with -
241
+
242
  1. **Pure DDPM Edit -** Describe the desired edited output image in detail
243
+ 2. **Pure SEGA Edit -** Keep the target prompt empty ***or*** with a description of the original image and add editing concepts for Semantic Gudiance editing
244
+ 3. **Combined -** Describe the desired edited output image in detail and add additional SEGA editing concepts on top
245
  - **Getting Started - Tips**
246
+
247
  While the best approach depends on your editing objective and source image, we can layout a few guiding tips to use as a starting point -
248
+
249
  1. **DDPM** is usually more suited for scene/style changes and major subject changes (for example ) while **SEGA** allows for more fine grained control, changes are more delicate, more suited for adding details (for example facial expressions and attributes, subtle style modifications, object adding/removing)
250
+ 2. The more you describe the scene in the target prompt (both the parts and details you wish to keep the same and those you wish to change), the better the result
251
+ 3. **Combining DDPM Edit with SEGA -**
252
+ Try dividing your editing objective to more significant scene/style/subject changes and detail adding/removing and more moderate changes. Then describe the major changes in a detailed target prompt and add the more fine grained details as SEGA concepts.
253
  4. **Reconstruction:** Using an empty source prompt + target prompt will lead to a perfect reconstruction
254
  - **Fidelity vs creativity**:
255
+
256
  Bigger values → more fidelity, smaller values → more creativity
257
+
258
+ 1. `Skip Steps`
259
  2. `Warmup` (SEGA)
260
  3. `Threshold` (SEGA)
261
+
262
  Bigger values → more creativity, smaller values → more fidelity
263
+
264
  1. `Guidance Scale`
265
  2. `Concept Guidance Scale` (SEGA)
266
  """
267
 
268
  with gr.Blocks(css='style.css') as demo:
269
+
270
  def add_concept(sega_concepts_counter):
271
  if sega_concepts_counter == 1:
272
  return row2.update(visible=True), row3.update(visible=False), add_concept_button.update(visible=True), 2
273
  else:
274
  return row2.update(visible=True), row3.update(visible=True), add_concept_button.update(visible=False), 3
275
 
276
+
277
  def reset_do_inversion():
278
  do_inversion = True
279
  return do_inversion
280
 
281
+ def reset_do_reconstruction():
282
+ do_reconstruction = True
283
+ return do_reconstruction
284
+
285
 
286
  def update_inversion_progress_visibility(do_inversion):
287
  if do_inversion:
288
  return inversion_progress.update(visible=True)
289
  else:
290
  return inversion_progress.update(visible=False)
 
 
291
 
292
+
293
+
294
+
295
+
296
+
297
  gr.HTML(intro)
298
  wts = gr.State()
299
  zs = gr.State()
 
300
  reconstruction = gr.State()
301
+ do_inversion = gr.State(value=True)
302
+ do_reconstruction = gr.State(value=True)
303
  sega_concepts_counter = gr.State(1)
304
 
305
+
306
+
307
  with gr.Row():
308
  input_image = gr.Image(label="Input Image", interactive=True)
309
  ddpm_edited_image = gr.Image(label=f"DDPM Reconstructed Image", interactive=False, visible=False)
 
313
  sega_edited_image.style(height=365, width=365)
314
 
315
  with gr.Row():
316
+ inversion_progress = gr.Textbox(visible=False, label="Inversion progress")
317
 
318
  with gr.Tabs() as tabs:
319
  with gr.TabItem('1. Describe the desired output', id=0):
 
330
  with gr.Row().style(mobile_collapse=False, equal_height=True):
331
  neg_guidance_1 = gr.Checkbox(
332
  label='Negative Guidance')
333
+
334
+ guidnace_scale_1 = gr.Slider(label='Concept Guidance Scale', minimum=1, maximum=30,
335
+ value=DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE,
 
 
336
  step=0.5, interactive=True)
337
+
 
338
  edit_concept_1 = gr.Textbox(
339
  label="Edit Concept",
340
  show_label=False,
341
  max_lines=1,
342
  placeholder="Enter your 1st edit prompt",
343
  )
344
+
345
  # 2nd SEGA concept
346
  with gr.Row(visible=False) as row2:
347
  neg_guidance_2 = gr.Checkbox(
348
  label='Negative Guidance',visible=True)
349
+ guidnace_scale_2 = gr.Slider(label='Concept Guidance Scale', minimum=1, maximum=30,
350
+ value=DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE,
 
 
 
351
  step=0.5, interactive=True)
 
 
 
352
  edit_concept_2 = gr.Textbox(
353
  label="Edit Concept",
354
  show_label=False,
 
359
  with gr.Row(visible=False) as row3:
360
  neg_guidance_3 = gr.Checkbox(
361
  label='Negative Guidance',visible=True)
362
+
363
+ guidnace_scale_3 = gr.Slider(label='Concept Guidance Scale', minimum=1, maximum=30,
364
+ value=DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE,
 
 
365
  step=0.5, interactive=True)
366
+
 
 
367
  edit_concept_3 = gr.Textbox(
368
  label="Edit Concept",
369
  show_label=False,
370
  max_lines=1,
371
  placeholder="Enter your 3rd edit prompt",
372
  )
373
+
374
  with gr.Row().style(mobile_collapse=False, equal_height=True):
375
  add_concept_button = gr.Button("+")
376
 
377
+
378
  with gr.Row():
379
  run_button = gr.Button("Edit")
380
  reconstruct_button = gr.Button("Show Reconstruction", visible=False)
381
+ clear_button = gr.Button("Clear")
 
 
382
 
383
  with gr.Accordion("Advanced Options", open=False):
384
+ with gr.Tabs() as tabs:
385
+ with gr.TabItem('General options', id=2):
386
  with gr.Row():
387
  with gr.Column():
388
  src_prompt = gr.Textbox(lines=1, label="Source Prompt", interactive=True, placeholder="")
 
390
  src_cfg_scale = gr.Number(value=3.5, label=f"Source Guidance Scale", interactive=True)
391
  seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
392
  randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
393
+ with gr.Column():
394
  skip = gr.Slider(minimum=0, maximum=60, value=36, label="Skip Steps", interactive=True)
395
+ tar_cfg_scale = gr.Slider(minimum=7, maximum=30,value=15, label=f"Guidance Scale", interactive=True)
396
+ with gr.TabItem('SEGA options', id=3):
397
+ # 1st SEGA concept
398
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
399
+ warmup_1 = gr.Slider(label='Warmup', minimum=0, maximum=50,
400
+ value=DEFAULT_WARMUP_STEPS,
401
+ step=1, interactive=True)
402
+ threshold_1 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99,
403
+ value=DEFAULT_THRESHOLD, steps=0.01, interactive=True)
404
+
405
+ # 2nd SEGA concept
406
+ with gr.Row(visible=False) as row2_advanced:
407
+ warmup_2 = gr.Slider(label='Warmup', minimum=0, maximum=50,
408
+ value=DEFAULT_WARMUP_STEPS,
409
+ step=1, interactive=True)
410
+ threshold_2 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99,
411
+ value=DEFAULT_THRESHOLD,
412
+ steps=0.01, interactive=True)
413
+ # 3rd SEGA concept
414
+ with gr.Row(visible=False) as row3_advanced:
415
+ warmup_3 = gr.Slider(label='Warmup', minimum=0, maximum=50,
416
+ value=DEFAULT_WARMUP_STEPS, step=1,
417
+ interactive=True)
418
+ threshold_3 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99,
419
+ value=DEFAULT_THRESHOLD, steps=0.01,
420
+ interactive=True)
421
+
422
 
423
 
 
424
 
425
 
426
  # with gr.Accordion("Help", open=False):
427
  # gr.Markdown(help_text)
428
+
429
  caption_button.click(
430
  fn = caption_image,
431
  inputs = [input_image],
432
  outputs = [tar_prompt]
433
  )
434
+
435
  add_concept_button.click(fn = add_concept, inputs=sega_concepts_counter,
436
  outputs= [row2, row3, add_concept_button, sega_concepts_counter], queue = False)
437
+
438
+ run_button.click(fn = update_inversion_progress_visibility, inputs =[do_inversion], outputs=[inversion_progress],queue=False).then(
 
 
 
 
439
  fn=load_and_invert,
440
+ inputs=[input_image,
441
  do_inversion,
442
  seed, randomize_seed,
443
+ wts, zs,
444
+ src_prompt,
445
+ tar_prompt,
446
  steps,
447
  src_cfg_scale,
448
  skip,
449
+ tar_cfg_scale
450
  ],
451
  outputs=[wts, zs, do_inversion, inversion_progress],
452
  ).then(fn = update_inversion_progress_visibility, inputs =[do_inversion], outputs=[inversion_progress],queue=False).success(
453
  fn=edit,
454
+ inputs=[input_image,
455
+ wts, zs,
456
+ tar_prompt,
457
  steps,
458
  skip,
459
  tar_cfg_scale,
 
460
  edit_concept_1,edit_concept_2,edit_concept_3,
461
  guidnace_scale_1,guidnace_scale_2,guidnace_scale_3,
462
  warmup_1, warmup_2, warmup_3,
 
464
  threshold_1, threshold_2, threshold_3
465
 
466
  ],
467
+ outputs=[sega_edited_image, reconstruct_button])
 
468
 
469
 
470
 
471
  # Automatically start inverting upon input_image change
472
  input_image.change(
473
  fn = reset_do_inversion,
474
+ outputs = [do_inversion],
475
+ queue = False).then(fn = update_inversion_progress_visibility, inputs =[do_inversion],
476
+ outputs=[inversion_progress],queue=False).then(
477
  fn=load_and_invert,
478
+ inputs=[input_image,
479
  do_inversion,
480
  seed, randomize_seed,
481
+ wts, zs,
482
+ src_prompt,
483
+ tar_prompt,
484
  steps,
485
  src_cfg_scale,
486
  skip,
487
+ tar_cfg_scale,
488
  ],
489
  # outputs=[ddpm_edited_image, wts, zs, do_inversion],
490
  outputs=[wts, zs, do_inversion, inversion_progress],
491
+ ).then(fn = update_inversion_progress_visibility, inputs =[do_inversion],
492
+ outputs=[inversion_progress],queue=False).then(
493
+ lambda: reconstruct_button.update(visible=False),
494
+ outputs=[reconstruct_button]).then(
495
+ fn = reset_do_reconstruction,
496
+ outputs = [do_reconstruction],
497
+ queue = False)
498
 
499
+
500
+ # Repeat inversion (and reconstruction) when these params are changed:
501
  src_prompt.change(
502
  fn = reset_do_inversion,
503
+ outputs = [do_inversion], queue = False).then(
504
+ fn = reset_do_reconstruction,
505
+ outputs = [do_reconstruction], queue = False)
506
+
507
+ steps.change(
508
+ fn = reset_do_inversion,
509
+ outputs = [do_inversion], queue = False).then(
510
+ fn = reset_do_reconstruction,
511
+ outputs = [do_reconstruction], queue = False)
512
+
513
+
514
+ src_cfg_scale.change(
515
+ fn = reset_do_inversion,
516
+ outputs = [do_inversion], queue = False).then(
517
+ fn = reset_do_reconstruction,
518
+ outputs = [do_reconstruction], queue = False)
519
+
520
+ # Repeat only reconstruction these params are changed:
521
+
522
+ tar_prompt.change(
523
+ fn = reset_do_reconstruction,
524
+ outputs = [do_reconstruction], queue = False)
525
+
526
+ tar_cfg_scale.change(
527
+ fn = reset_do_reconstruction,
528
+ outputs = [do_reconstruction], queue = False)
529
 
530
+ skip.change(
531
+ fn = reset_do_reconstruction,
532
+ outputs = [do_reconstruction], queue = False)
533
 
534
+
535
+
536
+ clear_components = [input_image,ddpm_edited_image,ddpm_edited_image,sega_edited_image, do_inversion,
537
  src_prompt, steps, src_cfg_scale, seed,
538
+ tar_prompt, skip, tar_cfg_scale, reconstruct_button,reconstruct_button,
539
  edit_concept_1, guidnace_scale_1,warmup_1, threshold_1, neg_guidance_1,
540
  edit_concept_2, guidnace_scale_2,warmup_2, threshold_2, neg_guidance_2,
541
+ edit_concept_3, guidnace_scale_3,warmup_3, threshold_3, neg_guidance_3,]
542
+
543
+ clear_components_output_vals = [None, None,ddpm_edited_image.update(visible=False), None, True,
 
544
  "", DEFAULT_DIFFUSION_STEPS, DEFAULT_SOURCE_GUIDANCE_SCALE, DEFAULT_SEED,
545
+ "", DEFAULT_SKIP_STEPS, DEFAULT_TARGET_GUIDANCE_SCALE, reconstruct_button.update(value="Show Reconstruction"),reconstruct_button.update(visible=False),
546
  "", DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE, DEFAULT_WARMUP_STEPS, DEFAULT_THRESHOLD, DEFAULT_NEGATIVE_GUIDANCE,
547
  "", DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE, DEFAULT_WARMUP_STEPS, DEFAULT_THRESHOLD, DEFAULT_NEGATIVE_GUIDANCE,
548
  "", DEFAULT_SEGA_CONCEPT_GUIDANCE_SCALE, DEFAULT_WARMUP_STEPS, DEFAULT_THRESHOLD, DEFAULT_NEGATIVE_GUIDANCE,
549
+ ]
550
+
551
+
552
+ clear_button.click(lambda: clear_components_output_vals, outputs =clear_components)
553
+
554
+ reconstruct_button.click(lambda: ddpm_edited_image.update(visible=True), outputs=[ddpm_edited_image]).then(fn = reconstruct,
555
+ inputs = [tar_prompt,
556
+ tar_cfg_scale,
557
+ skip,
558
+ wts, zs,
559
+ do_reconstruction,
560
+ reconstruction,
561
+ reconstruct_button],
562
+ outputs = [ddpm_edited_image,reconstruction, ddpm_edited_image, do_reconstruction, reconstruct_button])
563
 
564
+ randomize_seed.change(
565
+ fn = randomize_seed_fn,
566
+ inputs = [seed, randomize_seed],
567
+ outputs = [seed],
568
+ queue = False)
569
+
570
  # gr.Examples(
571
+ # label='Examples',
572
+ # examples=get_example(),
573
  # inputs=[input_image, src_prompt, tar_prompt, steps,
574
  # # src_cfg_scale,
575
  # skip,
 
589
 
590
 
591
 
592
+
593
  demo.queue()
594
  demo.launch(share=False)
595