Linoy Tsaban commited on
Commit
017df60
1 Parent(s): 2004272

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -70
app.py CHANGED
@@ -47,7 +47,6 @@ def sample(wt, zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
47
 
48
  # load pipelines
49
  sd_model_id = "runwayml/stable-diffusion-v1-5"
50
- # sd_model_id = "stabilityai/stable-diffusion-2-base"
51
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
  sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
53
  sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
@@ -112,64 +111,49 @@ def get_example():
112
 
113
 
114
  def invert_and_reconstruct(
115
- input_image,
 
 
 
116
  src_prompt ="",
117
  tar_prompt="",
118
  steps=100,
119
- # src_cfg_scale,
120
  skip=36,
121
  tar_cfg_scale=15,
122
  # neg_guidance=False,
123
- seed =0
124
  ):
125
- offsets=(0,0,0,0)
126
  torch.manual_seed(seed)
127
  x0 = load_512(input_image, device=device)
128
 
 
 
 
 
 
 
129
 
130
- # invert
131
- # wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
132
- wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps)
133
 
134
- latnets = wts[skip].expand(1, -1, -1, -1)
135
 
136
 
137
- #pure DDPM output
138
- pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
139
- cfg_scale_tar=tar_cfg_scale, skip=skip)
140
- # inversion_map['latnets'] = latnets
141
- # inversion_map['zs'] = zs
142
- # inversion_map['wts'] = wts
143
 
144
- return pure_ddpm_out
 
 
 
 
 
 
 
 
 
 
 
145
 
146
-
147
- def edit(input_image,
148
- src_prompt ="",
149
- tar_prompt="",
150
- steps=100,
151
- # src_cfg_scale,
152
- skip=36,
153
- tar_cfg_scale=15,
154
- edit_concept="",
155
- sega_edit_guidance=10,
156
- warm_up=None,
157
- # neg_guidance=False,
158
- seed =0,
159
  ):
160
- torch.manual_seed(seed)
161
- # if not bool(inversion_map):
162
- # raise gr.Error("Must invert before editing")
163
-
164
-
165
-
166
- x0 = load_512(input_image, device=device)
167
-
168
- # invert
169
- # wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
170
- wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps)
171
-
172
- latnets = wts[skip].expand(1, -1, -1, -1)
173
 
174
  # SEGA
175
  # parse concepts and neg guidance
@@ -205,10 +189,11 @@ def edit(input_image,
205
  edit_momentum_scale=0.5,
206
  edit_mom_beta=0.6
207
  )
 
208
  sega_out = sem_pipe(prompt=tar_prompt,eta=1, latents=latnets, guidance_scale = tar_cfg_scale,
209
  num_images_per_prompt=1,
210
  num_inference_steps=steps,
211
- use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
212
  return sega_out.images[0]
213
 
214
  ########
@@ -230,8 +215,15 @@ For faster inference without waiting in queue, you may duplicate the space and u
230
  <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
231
  <p/>"""
232
  with gr.Blocks(css='style.css') as demo:
 
 
 
 
 
233
  gr.HTML(intro)
234
-
 
 
235
 
236
  with gr.Row():
237
  input_image = gr.Image(label="Input Image", interactive=True)
@@ -243,7 +235,7 @@ with gr.Blocks(css='style.css') as demo:
243
 
244
  with gr.Row():
245
  tar_prompt = gr.Textbox(lines=1, label="Target Prompt", interactive=True, placeholder="")
246
- edit_concept = gr.Textbox(lines=1, label="SEGA Edit Concepts", visible = False, interactive=True)
247
 
248
  with gr.Row():
249
  with gr.Column(scale=1, min_width=100):
@@ -257,13 +249,13 @@ with gr.Blocks(css='style.css') as demo:
257
  #inversion
258
  src_prompt = gr.Textbox(lines=1, label="Source Prompt", interactive=True, placeholder="")
259
  steps = gr.Number(value=100, precision=0, label="Num Diffusion Steps", interactive=True)
260
- # src_cfg_scale = gr.Number(value=3.5, label=f"Source CFG", interactive=True)
261
-
262
- # reconstruction
263
- skip = gr.Slider(minimum=0, maximum=40, value=36, precision=0, label="Skip Steps", interactive=True)
264
- tar_cfg_scale = gr.Slider(minimum=7, maximum=18,value=15, label=f"Guidance Scale", interactive=True)
265
  seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
 
266
  with gr.Column():
 
 
 
267
  sega_edit_guidance = gr.Slider(value=10, label=f"SEGA Edit Guidance Scale", interactive=True)
268
  warm_up = gr.Textbox(label=f"SEGA Warm-up Steps", interactive=True, placeholder="type #warm-up steps for each concpets (e.g. 2,7,5...")
269
 
@@ -274,37 +266,49 @@ with gr.Blocks(css='style.css') as demo:
274
  # gr.Markdown(help_text)
275
 
276
  invert_button.click(
 
 
 
 
277
  fn=invert_and_reconstruct,
278
  inputs=[input_image,
279
- src_prompt,
280
- tar_prompt,
281
- steps,
282
- # src_cfg_scale,
283
- skip,
284
- tar_cfg_scale,
285
- # neg_guidance,
286
- seed
 
287
  ],
288
- outputs=[ddpm_edited_image],
289
  )
290
 
291
  edit_button.click(
292
  fn=edit,
293
  inputs=[input_image,
294
- src_prompt,
295
- tar_prompt,
296
- steps,
297
- # src_cfg_scale,
298
- skip,
299
- tar_cfg_scale,
300
- edit_concept,
301
- sega_edit_guidance,
302
- warm_up,
303
- # neg_guidance,
304
- seed
 
305
 
306
  ],
307
  outputs=[sega_edited_image],
 
 
 
 
 
 
308
  )
309
 
310
  gr.Examples(
 
47
 
48
  # load pipelines
49
  sd_model_id = "runwayml/stable-diffusion-v1-5"
 
50
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
  sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
52
  sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder = "scheduler")
 
111
 
112
 
113
  def invert_and_reconstruct(
114
+ input_image,
115
+ do_inversion,
116
+ wts, zs,
117
+ seed,
118
  src_prompt ="",
119
  tar_prompt="",
120
  steps=100,
121
+ src_cfg_scale = 3.5,
122
  skip=36,
123
  tar_cfg_scale=15,
124
  # neg_guidance=False,
125
+
126
  ):
 
127
  torch.manual_seed(seed)
128
  x0 = load_512(input_image, device=device)
129
 
130
+ if do_inversion:
131
+ # invert and retrieve noise maps and latent
132
+ zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src)
133
+ wts = gr.State(value=wts_tensor)
134
+ zs = gr.State(value=zs_tensor)
135
+ do_inversion = False
136
 
137
+ output = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=cfg_scale_tar)
 
 
138
 
139
+ return output, wts, zs, do_inversion
140
 
141
 
 
 
 
 
 
 
142
 
143
+ def edit(input_image,
144
+ do_inversion,
145
+ wts, zs, seed,
146
+ src_prompt ="",
147
+ tar_prompt="",
148
+ steps=100,
149
+ skip=36,
150
+ tar_cfg_scale=15,
151
+ edit_concept="",
152
+ sega_edit_guidance=10,
153
+ warm_up=None,
154
+ # neg_guidance=False,
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  # SEGA
159
  # parse concepts and neg guidance
 
189
  edit_momentum_scale=0.5,
190
  edit_mom_beta=0.6
191
  )
192
+ latnets = wts.value[skip].expand(1, -1, -1, -1)
193
  sega_out = sem_pipe(prompt=tar_prompt,eta=1, latents=latnets, guidance_scale = tar_cfg_scale,
194
  num_images_per_prompt=1,
195
  num_inference_steps=steps,
196
+ use_ddpm=True, wts=wts.value, zs=zs.value[skip:], **editing_args)
197
  return sega_out.images[0]
198
 
199
  ########
 
215
  <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
216
  <p/>"""
217
  with gr.Blocks(css='style.css') as demo:
218
+
219
+ def reset_do_inversion():
220
+ do_inversion = True
221
+ return do_inversion
222
+
223
  gr.HTML(intro)
224
+ wts = gr.State()
225
+ zs = gr.State()
226
+ do_inversion = gr.State(value=True)
227
 
228
  with gr.Row():
229
  input_image = gr.Image(label="Input Image", interactive=True)
 
235
 
236
  with gr.Row():
237
  tar_prompt = gr.Textbox(lines=1, label="Target Prompt", interactive=True, placeholder="")
238
+ edit_concept = gr.Textbox(lines=1, label="SEGA Edit Concepts", visible = True, interactive=True)
239
 
240
  with gr.Row():
241
  with gr.Column(scale=1, min_width=100):
 
249
  #inversion
250
  src_prompt = gr.Textbox(lines=1, label="Source Prompt", interactive=True, placeholder="")
251
  steps = gr.Number(value=100, precision=0, label="Num Diffusion Steps", interactive=True)
252
+ src_cfg_scale = gr.Number(value=3.5, label=f"Source Guidance Scale", interactive=True)
 
 
 
 
253
  seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
254
+ randomize_seed = gr.Checkbox(label='Randomize seed', value=True)
255
  with gr.Column():
256
+ # reconstruction
257
+ skip = gr.Slider(minimum=0, maximum=40, value=36, precision=0, label="Skip Steps", interactive=True)
258
+ tar_cfg_scale = gr.Slider(minimum=7, maximum=18,value=15, label=f"Guidance Scale", interactive=True)
259
  sega_edit_guidance = gr.Slider(value=10, label=f"SEGA Edit Guidance Scale", interactive=True)
260
  warm_up = gr.Textbox(label=f"SEGA Warm-up Steps", interactive=True, placeholder="type #warm-up steps for each concpets (e.g. 2,7,5...")
261
 
 
266
  # gr.Markdown(help_text)
267
 
268
  invert_button.click(
269
+ fn = randomize_seed_fn,
270
+ inputs = [seed, randomize_seed],
271
+ outputs = [seed]
272
+ ).then(
273
  fn=invert_and_reconstruct,
274
  inputs=[input_image,
275
+ do_inversion,
276
+ wts, zs,
277
+ seed,
278
+ src_prompt,
279
+ tar_prompt,
280
+ steps,
281
+ src_cfg_scale,
282
+ skip,
283
+ tar_cfg_scale,
284
  ],
285
+ outputs=[ddpm_edited_image, wts, zs, do_inversion],
286
  )
287
 
288
  edit_button.click(
289
  fn=edit,
290
  inputs=[input_image,
291
+ do_inversion,
292
+ wts, zs,
293
+ seed,
294
+ src_prompt,
295
+ tar_prompt,
296
+ steps,
297
+ skip,
298
+ tar_cfg_scale,
299
+ edit_concept,
300
+ sega_edit_guidance,
301
+ warm_up,
302
+ # neg_guidance,
303
 
304
  ],
305
  outputs=[sega_edited_image],
306
+
307
+ )
308
+
309
+ input_image.change(
310
+ fn = reset_do_inversion,
311
+ outputs = [do_inversion]
312
  )
313
 
314
  gr.Examples(