Linoy Tsaban commited on
Commit
6a5a59b
1 Parent(s): af56f98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -20
app.py CHANGED
@@ -14,12 +14,6 @@ import re
14
 
15
 
16
 
17
- def randomize_seed_fn(seed, randomize_seed):
18
- if randomize_seed:
19
- seed = random.randint(0, np.iinfo(np.int32).max)
20
- torch.manual_seed(seed)
21
- return seed
22
-
23
 
24
  def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
25
 
@@ -116,8 +110,29 @@ def get_example():
116
  ]]
117
  return case
118
 
 
 
 
 
 
 
 
 
119
 
120
- def invert_and_reconstruct(
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  input_image,
122
  do_inversion,
123
  seed, randomize_seed,
@@ -127,7 +142,7 @@ def invert_and_reconstruct(
127
  steps=100,
128
  src_cfg_scale = 3.5,
129
  skip=36,
130
- tar_cfg_scale=15,
131
 
132
  ):
133
 
@@ -140,10 +155,7 @@ def invert_and_reconstruct(
140
  wts = gr.State(value=wts_tensor)
141
  zs = gr.State(value=zs_tensor)
142
  do_inversion = False
143
-
144
- # output = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
145
-
146
- # return output, wts, zs, do_inversion
147
  return wts, zs, do_inversion
148
 
149
 
@@ -244,7 +256,10 @@ with gr.Blocks(css='style.css') as demo:
244
  else:
245
  return row2.update(visible=True), row3.update(visible=True), plus.update(visible=False), 3
246
 
247
-
 
 
 
248
  def reset_do_inversion():
249
  do_inversion = True
250
  return do_inversion
@@ -255,15 +270,16 @@ with gr.Blocks(css='style.css') as demo:
255
  zs = gr.State()
256
  do_inversion = gr.State(value=True)
257
  sega_concepts_counter = gr.State(1)
 
258
 
259
 
260
 
261
  with gr.Row():
262
  input_image = gr.Image(label="Input Image", interactive=True)
263
- # ddpm_edited_image = gr.Image(label=f"DDPM Reconstructed Image", interactive=False, visible=False)
264
  sega_edited_image = gr.Image(label=f"DDPM + SEGA Edited Image", interactive=False)
265
  input_image.style(height=365, width=365)
266
- # ddpm_edited_image.style(height=512, width=512)
267
  sega_edited_image.style(height=365, width=365)
268
 
269
  with gr.Tabs() as tabs:
@@ -322,12 +338,13 @@ with gr.Blocks(css='style.css') as demo:
322
  )
323
 
324
  with gr.Row().style(mobile_collapse=False, equal_height=True):
325
- plus = gr.Button("+")
326
 
327
 
328
  with gr.Row():
329
  with gr.Column(scale=1, min_width=100):
330
  run_button = gr.Button("Run")
 
331
  # with gr.Column(scale=1, min_width=100):
332
  # edit_button = gr.Button("Edit")
333
 
@@ -350,16 +367,25 @@ with gr.Blocks(css='style.css') as demo:
350
 
351
 
352
 
353
- plus.click(fn = add_concept, inputs=sega_concepts_counter,
354
  outputs= [row2, row3, plus, sega_concepts_counter], queue = False)
355
 
 
 
 
 
 
 
 
 
 
356
 
357
  run_button.click(
358
  fn = randomize_seed_fn,
359
  inputs = [seed, randomize_seed],
360
  outputs = [seed],
361
  queue = False).then(
362
- fn=invert_and_reconstruct,
363
  inputs=[input_image,
364
  do_inversion,
365
  seed, randomize_seed,
@@ -369,10 +395,10 @@ with gr.Blocks(css='style.css') as demo:
369
  steps,
370
  src_cfg_scale,
371
  skip,
372
- tar_cfg_scale,
373
  ],
374
- # outputs=[ddpm_edited_image, wts, zs, do_inversion],
375
  outputs=[wts, zs, do_inversion],
 
376
  ).success(
377
  fn=edit,
378
  inputs=[input_image,
@@ -389,8 +415,17 @@ with gr.Blocks(css='style.css') as demo:
389
 
390
  ],
391
  outputs=[sega_edited_image],
 
 
 
392
  )
393
 
 
 
 
 
 
 
394
  # Automatically start inverting upon input_image change
395
  input_image.change(
396
  fn = reset_do_inversion,
 
14
 
15
 
16
 
 
 
 
 
 
 
17
 
18
  def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
19
 
 
110
  ]]
111
  return case
112
 
113
+ def randomize_seed_fn(seed, randomize_seed):
114
+ if randomize_seed:
115
+ seed = random.randint(0, np.iinfo(np.int32).max)
116
+ torch.manual_seed(seed)
117
+ return seed
118
+
119
+
120
+
121
 
122
+ def reconstruct(tar_prompt,
123
+ tar_cfg_scale,
124
+ skip,
125
+ wts, zs,
126
+ # do_reconstruction,
127
+ # reconstruction
128
+ )
129
+
130
+ ):
131
+ # if do_reconstruction:
132
+ reconstruction = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
133
+ return reconstruction
134
+
135
+ def load_and_invert(
136
  input_image,
137
  do_inversion,
138
  seed, randomize_seed,
 
142
  steps=100,
143
  src_cfg_scale = 3.5,
144
  skip=36,
145
+ tar_cfg_scale=15
146
 
147
  ):
148
 
 
155
  wts = gr.State(value=wts_tensor)
156
  zs = gr.State(value=zs_tensor)
157
  do_inversion = False
158
+
 
 
 
159
  return wts, zs, do_inversion
160
 
161
 
 
256
  else:
257
  return row2.update(visible=True), row3.update(visible=True), plus.update(visible=False), 3
258
 
259
+ def show_reconstruction_option():
260
+ return reconstruct_button.update(visible=True)
261
+
262
+
263
  def reset_do_inversion():
264
  do_inversion = True
265
  return do_inversion
 
270
  zs = gr.State()
271
  do_inversion = gr.State(value=True)
272
  sega_concepts_counter = gr.State(1)
273
+ # reconstruction = gr.State()
274
 
275
 
276
 
277
  with gr.Row():
278
  input_image = gr.Image(label="Input Image", interactive=True)
279
+ ddpm_edited_image = gr.Image(label=f"DDPM Reconstructed Image", interactive=False, visible=False)
280
  sega_edited_image = gr.Image(label=f"DDPM + SEGA Edited Image", interactive=False)
281
  input_image.style(height=365, width=365)
282
+ ddpm_edited_image.style(height=512, width=512)
283
  sega_edited_image.style(height=365, width=365)
284
 
285
  with gr.Tabs() as tabs:
 
338
  )
339
 
340
  with gr.Row().style(mobile_collapse=False, equal_height=True):
341
+ add_concept_button = gr.Button("+")
342
 
343
 
344
  with gr.Row():
345
  with gr.Column(scale=1, min_width=100):
346
  run_button = gr.Button("Run")
347
+ reconstruct_button = gr.Button("Show me the reconstruction")
348
  # with gr.Column(scale=1, min_width=100):
349
  # edit_button = gr.Button("Edit")
350
 
 
367
 
368
 
369
 
370
+ add_concept_button.click(fn = add_concept, inputs=sega_concepts_counter,
371
  outputs= [row2, row3, plus, sega_concepts_counter], queue = False)
372
 
373
+ reconstruct_button.click(
374
+ fn = reconstruct,
375
+ inputs = [tar_prompt,
376
+ tar_cfg_scale,
377
+ skip,
378
+ wts, zs]
379
+ outputs = [ddpm_edited_image]
380
+ )
381
+
382
 
383
  run_button.click(
384
  fn = randomize_seed_fn,
385
  inputs = [seed, randomize_seed],
386
  outputs = [seed],
387
  queue = False).then(
388
+ fn=load_and_invert,
389
  inputs=[input_image,
390
  do_inversion,
391
  seed, randomize_seed,
 
395
  steps,
396
  src_cfg_scale,
397
  skip,
398
+ tar_cfg_scale
399
  ],
 
400
  outputs=[wts, zs, do_inversion],
401
+
402
  ).success(
403
  fn=edit,
404
  inputs=[input_image,
 
415
 
416
  ],
417
  outputs=[sega_edited_image],
418
+ ).success(
419
+ fn = show_reconstruction_option,
420
+ outputs = [reconstruct_button]
421
  )
422
 
423
+ reconstruct_button.click(
424
+ fn =
425
+ )
426
+
427
+
428
+
429
  # Automatically start inverting upon input_image change
430
  input_image.change(
431
  fn = reset_do_inversion,