Linoy Tsaban commited on
Commit
7b3a214
1 Parent(s): 69cb2e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -87
app.py CHANGED
@@ -132,7 +132,6 @@ def invert_and_reconstruct(
132
  src_cfg_scale = 3.5,
133
  skip=36,
134
  tar_cfg_scale=15,
135
- # neg_guidance=False,
136
 
137
  ):
138
 
@@ -146,16 +145,11 @@ def invert_and_reconstruct(
146
  zs = gr.State(value=zs_tensor)
147
  do_inversion = False
148
 
149
- output = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
150
 
151
- return output, wts, zs, do_inversion
 
152
 
153
- def update_sega_concept_table(edit_concept, neg_guidance, concepts_table):
154
- if edit_concept:
155
- new_rows = concepts_table.value.append([edit_concept, neg_guidance])
156
- new_concepts_table = gr.DataFrame(value=new_rows)
157
- return new_concepts_table
158
- return concepts_table
159
 
160
  def edit(input_image,
161
  wts, zs,
@@ -163,49 +157,29 @@ def edit(input_image,
163
  steps=100,
164
  skip=36,
165
  tar_cfg_scale=15,
166
- edit_concept="",
167
- sega_edit_guidance=10,
168
- warm_up=None,
169
- # neg_guidance=False,
 
170
 
171
  ):
172
 
173
  # SEGA
174
  # parse concepts and neg guidance
175
- edit_concepts = edit_concept.split(",")
176
- num_concepts = len(edit_concepts)
177
- neg_guidance =[]
178
- for edit_concept in edit_concepts:
179
- edit_concept=edit_concept.strip(" ")
180
- if edit_concept.startswith("-"):
181
- neg_guidance.append(True)
182
- else:
183
- neg_guidance.append(False)
184
- edit_concepts = [concept.strip("+|-") for concept in edit_concepts]
185
-
186
- # parse warm-up steps
187
- default_warm_up_steps = [1]*num_concepts
188
- if warm_up:
189
- digit_pattern = re.compile(r"^\d+$")
190
- warm_up_steps_str = warm_up.split(",")
191
- for i,num_steps in enumerate(warm_up_steps_str[:num_concepts]):
192
- if not digit_pattern.match(num_steps):
193
- raise gr.Error("Invalid value for warm-up steps, using 1 instead")
194
- else:
195
- default_warm_up_steps[i] = int(num_steps)
196
-
197
-
198
  editing_args = dict(
199
- editing_prompt = edit_concepts,
200
- reverse_editing_direction = neg_guidance,
201
- edit_warmup_steps=default_warm_up_steps,
202
- edit_guidance_scale=[sega_edit_guidance]*num_concepts,
203
- edit_threshold=[.95]*num_concepts,
204
  edit_momentum_scale=0.5,
205
- edit_mom_beta=0.6
 
206
  )
207
  latnets = wts.value[skip].expand(1, -1, -1, -1)
208
- sega_out = sem_pipe(prompt=tar_prompt,eta=1, latents=latnets, guidance_scale = tar_cfg_scale,
209
  num_images_per_prompt=1,
210
  num_inference_steps=steps,
211
  use_ddpm=True, wts=wts.value, zs=zs.value[skip:], **editing_args)
@@ -246,7 +220,7 @@ with gr.Blocks(css='style.css') as demo:
246
 
247
  with gr.Row():
248
  input_image = gr.Image(label="Input Image", interactive=True)
249
- ddpm_edited_image = gr.Image(label=f"DDPM Reconstructed Image", interactive=False, visible=False)
250
  sega_edited_image = gr.Image(label=f"DDPM + SEGA Edited Image", interactive=False)
251
  input_image.style(height=512, width=512)
252
  ddpm_edited_image.style(height=512, width=512)
@@ -254,31 +228,20 @@ with gr.Blocks(css='style.css') as demo:
254
 
255
  with gr.Row():
256
  tar_prompt = gr.Textbox(lines=1, label="Target Prompt", interactive=True, placeholder="")
257
- # with gr.Accordion("SEGA Concepts", open=False, visible=True):
258
- # with gr.Column(scale=1):
259
- # edit_concept = gr.Textbox(lines=1, label="Enter SEGA Edit Concept", visible = True, interactive=True)
260
- # with gr.Column(scale=1):
261
- # neg_guidance = gr.Checkbox(label="Negative Guidance", value=False)
262
- # submit_concept = gr.Button(label="Add Concept")
263
- # concepts_table = gr.Dataframe(
264
- # headers=["Concepts", "Negative Guidance"],
265
- # datatype=["str", "bool"],
266
- # label="SEGA Concepts",
267
- # )
268
-
269
-
270
  with gr.Row():
271
  with gr.Column(scale=1, min_width=100):
272
- invert_button = gr.Button("Invert")
273
- with gr.Column(scale=1, min_width=100):
274
- edit_button = gr.Button("Edit")
275
 
276
  with gr.Accordion("Advanced Options", open=False):
277
  with gr.Tabs() as tabs:
278
  with gr.TabItem('SEGA Guidance', id=0):
279
  with gr.Row().style(mobile_collapse=False, equal_height=True):
280
- edit_1 = gr.Textbox(
281
- label="Edit Prompt 1",
282
  show_label=False,
283
  max_lines=1,
284
  placeholder="Enter your 1st edit prompt",
@@ -289,11 +252,12 @@ with gr.Blocks(css='style.css') as demo:
289
  )
290
  with gr.Group():
291
  with gr.Row().style(mobile_collapse=False, equal_height=True):
292
- rev_1 = gr.Checkbox(
293
  label='Negative Guidance')
294
  warmup_1 = gr.Slider(label='Warmup', minimum=0, maximum=50, value=10, step=1, interactive=True)
295
- scale_1 = gr.Slider(label='Scale', minimum=1, maximum=10, value=5, step=0.25, interactive=True)
296
  threshold_1 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99, value=0.95, steps=0.01, interactive=True)
 
297
  with gr.TabItem('DDPM Guidance', id=1):
298
  with gr.Row():
299
  with gr.Column():
@@ -305,17 +269,14 @@ with gr.Blocks(css='style.css') as demo:
305
  with gr.Column():
306
  skip = gr.Slider(minimum=0, maximum=40, value=36, label="Skip Steps", interactive=True)
307
  tar_cfg_scale = gr.Slider(minimum=7, maximum=18,value=15, label=f"Guidance Scale", interactive=True)
308
- sega_edit_guidance = gr.Slider(value=10, label=f"SEGA Edit Guidance Scale", interactive=True)
309
- warm_up = gr.Textbox(label=f"SEGA Warm-up Steps", interactive=True, placeholder="type #warm-up steps for each concpets (e.g. 2,7,5...")
310
 
311
-
312
- # neg_guidance = gr.Checkbox(label="SEGA Negative Guidance")
313
 
314
 
315
  # gr.Markdown(help_text)
316
 
317
 
318
- invert_button.click(
319
  fn = randomize_seed_fn,
320
  inputs = [seed, randomize_seed],
321
  outputs = [seed],
@@ -332,26 +293,26 @@ with gr.Blocks(css='style.css') as demo:
332
  skip,
333
  tar_cfg_scale,
334
  ],
335
- outputs=[ddpm_edited_image, wts, zs, do_inversion],
336
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
- # edit_button.click(
339
- # fn=edit,
340
- # inputs=[input_image,
341
- # wts, zs,
342
- # tar_prompt,
343
- # steps,
344
- # skip,
345
- # tar_cfg_scale,
346
- # edit_concept,
347
- # sega_edit_guidance,
348
- # warm_up,
349
- # # neg_guidance,
350
-
351
- # ],
352
- # outputs=[sega_edited_image],
353
 
354
- # )
355
 
356
  input_image.change(
357
  fn = reset_do_inversion,
 
132
  src_cfg_scale = 3.5,
133
  skip=36,
134
  tar_cfg_scale=15,
 
135
 
136
  ):
137
 
 
145
  zs = gr.State(value=zs_tensor)
146
  do_inversion = False
147
 
148
+ # output = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
149
 
150
+ # return output, wts, zs, do_inversion
151
+ return wts, zs, do_inversion
152
 
 
 
 
 
 
 
153
 
154
  def edit(input_image,
155
  wts, zs,
 
157
  steps=100,
158
  skip=36,
159
  tar_cfg_scale=15,
160
+ edit_concept_1 = "",
161
+ guidnace_scale_1 = 10,
162
+ warmup_1 = 1,
163
+ neg_guidance_1 = False,
164
+ threshold_1 = 0.95
165
 
166
  ):
167
 
168
  # SEGA
169
  # parse concepts and neg guidance
170
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  editing_args = dict(
172
+ editing_prompt = [edit_concept_1],
173
+ reverse_editing_direction = [neg_guidance_1],
174
+ edit_warmup_steps=[warmup_1],
175
+ edit_guidance_scale=[guidnace_scale_1],
176
+ edit_threshold=[threshold_1],
177
  edit_momentum_scale=0.5,
178
+ edit_mom_beta=0.6,
179
+ eta=1,
180
  )
181
  latnets = wts.value[skip].expand(1, -1, -1, -1)
182
+ sega_out = sem_pipe(prompt=tar_prompt, latents=latnets, guidance_scale = tar_cfg_scale,
183
  num_images_per_prompt=1,
184
  num_inference_steps=steps,
185
  use_ddpm=True, wts=wts.value, zs=zs.value[skip:], **editing_args)
 
220
 
221
  with gr.Row():
222
  input_image = gr.Image(label="Input Image", interactive=True)
223
+ # ddpm_edited_image = gr.Image(label=f"DDPM Reconstructed Image", interactive=False, visible=False)
224
  sega_edited_image = gr.Image(label=f"DDPM + SEGA Edited Image", interactive=False)
225
  input_image.style(height=512, width=512)
226
  ddpm_edited_image.style(height=512, width=512)
 
228
 
229
  with gr.Row():
230
  tar_prompt = gr.Textbox(lines=1, label="Target Prompt", interactive=True, placeholder="")
231
+
232
+
 
 
 
 
 
 
 
 
 
 
 
233
  with gr.Row():
234
  with gr.Column(scale=1, min_width=100):
235
+ run_button = gr.Button("Run")
236
+ # with gr.Column(scale=1, min_width=100):
237
+ # edit_button = gr.Button("Edit")
238
 
239
  with gr.Accordion("Advanced Options", open=False):
240
  with gr.Tabs() as tabs:
241
  with gr.TabItem('SEGA Guidance', id=0):
242
  with gr.Row().style(mobile_collapse=False, equal_height=True):
243
+ edit_concept_1 = gr.Textbox(
244
+ label="Edit Concept",
245
  show_label=False,
246
  max_lines=1,
247
  placeholder="Enter your 1st edit prompt",
 
252
  )
253
  with gr.Group():
254
  with gr.Row().style(mobile_collapse=False, equal_height=True):
255
+ neg_guidance_1 = gr.Checkbox(
256
  label='Negative Guidance')
257
  warmup_1 = gr.Slider(label='Warmup', minimum=0, maximum=50, value=10, step=1, interactive=True)
258
+ guidnace_scale_1 = gr.Slider(label='Scale', minimum=1, maximum=10, value=5, step=0.25, interactive=True)
259
  threshold_1 = gr.Slider(label='Threshold', minimum=0.5, maximum=0.99, value=0.95, steps=0.01, interactive=True)
260
+
261
  with gr.TabItem('DDPM Guidance', id=1):
262
  with gr.Row():
263
  with gr.Column():
 
269
  with gr.Column():
270
  skip = gr.Slider(minimum=0, maximum=40, value=36, label="Skip Steps", interactive=True)
271
  tar_cfg_scale = gr.Slider(minimum=7, maximum=18,value=15, label=f"Guidance Scale", interactive=True)
 
 
272
 
273
+
 
274
 
275
 
276
  # gr.Markdown(help_text)
277
 
278
 
279
+ run_button.click(
280
  fn = randomize_seed_fn,
281
  inputs = [seed, randomize_seed],
282
  outputs = [seed],
 
293
  skip,
294
  tar_cfg_scale,
295
  ],
296
+ # outputs=[ddpm_edited_image, wts, zs, do_inversion],
297
+ outputs=[wts, zs, do_inversion],
298
+ ).success(
299
+ fn=edit,
300
+ inputs=[input_image,
301
+ wts, zs,
302
+ tar_prompt,
303
+ steps,
304
+ skip,
305
+ tar_cfg_scale,
306
+ edit_concept_1,
307
+ guidnace_scale_1,
308
+ warmup_1,
309
+ neg_guidance_1,
310
+ threshold_1
311
 
312
+ ],
313
+ outputs=[sega_edited_image],
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
+ )
316
 
317
  input_image.change(
318
  fn = reset_do_inversion,