czl committed
Commit 70ca747
Parent: 98b6f69

update layout

Files changed (1)
  app.py +196 -162
app.py CHANGED
@@ -50,11 +50,43 @@ def infer(
 ):
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
+    # Input Validation
+    try:
+        assert num_interpolation_steps % 2 == 0
+    except AssertionError:
+        raise ValueError("num_interpolation_steps must be an even number")
+    try:
+        assert sample_mid_interpolation % 2 == 0
+    except AssertionError:
+        raise ValueError("sample_mid_interpolation must be an even number")
+    try:
+        assert remove_n_middle % 2 == 0
+    except AssertionError:
+        raise ValueError("remove_n_middle must be an even number")
+    try:
+        assert num_interpolation_steps >= sample_mid_interpolation
+    except AssertionError:
+        raise ValueError(
+            "num_interpolation_steps must be greater than or equal to sample_mid_interpolation"
+        )
+    try:
+        assert num_interpolation_steps >= 2 and sample_mid_interpolation >= 2
+    except AssertionError:
+        raise ValueError(
+            "num_interpolation_steps and sample_mid_interpolation must be greater than or equal to 2"
+        )
+    try:
+        assert sample_mid_interpolation - remove_n_middle >= 2
+    except AssertionError:
+        raise ValueError(
+            "sample_mid_interpolation must be greater than or equal to remove_n_middle + 2"
+        )
+
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     prompts = [prompt1, prompt2]
     generator = torch.Generator().manual_seed(seed)
-    print(seed)
+
     interpolated_prompt_embeds, prompt_metadata = synth.interpolatePrompts(
         prompts,
         pipe,
@@ -65,18 +97,16 @@ def infer(
     )
     negative_prompts = [negative_prompt, negative_prompt]
     if negative_prompts != ["", ""]:
-        interpolated_negative_prompts_embeds, negative_prompt_metadata = (
-            synth.interpolatePrompts(
-                negative_prompts,
-                pipe,
-                num_interpolation_steps,
-                sample_mid_interpolation,
-                remove_n_middle=remove_n_middle,
-                device=device,
-            )
+        interpolated_negative_prompts_embeds, _ = synth.interpolatePrompts(
+            negative_prompts,
+            pipe,
+            num_interpolation_steps,
+            sample_mid_interpolation,
+            remove_n_middle=remove_n_middle,
+            device=device,
         )
     else:
-        interpolated_negative_prompts_embeds, negative_prompt_metadata = [None] * len(
+        interpolated_negative_prompts_embeds, _ = [None] * len(
             interpolated_prompt_embeds
         ), None
 
@@ -129,7 +159,7 @@ def infer(
         * 100
     )
 
-    return image, seed, ssim_score, cosine_sim
+    return image, seed, round(ssim_score, 4), round(cosine_sim, 2)
 
 
 examples1 = [
@@ -141,12 +171,6 @@ examples2 = [
     "A photo of a beagle, a type of dog",
 ]
 
-css = """
-#col-container {
-    margin: 0 auto;
-}
-"""
-
 
 def update_steps(total_steps, interpolation_step):
     if interpolation_step > total_steps:
@@ -154,181 +178,191 @@ def update_steps(total_steps, interpolation_step):
     return gr.update(maximum=total_steps // 2)
 
 
+def update_sampling_steps(total_steps, sample_steps):
+    # if sample_steps > total_steps:
+    #     return gr.update(value=total_steps)
+    return gr.update(value=total_steps)
+
+
 if torch.cuda.is_available():
     power_device = "GPU"
 else:
     power_device = "CPU"
 
-with gr.Blocks(css=css, title="Generative Date Augmentation") as demo:
+with gr.Blocks(title="Generative Date Augmentation") as demo:
 
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(
-            f"""
-        # Data Augmentation with Image-to-Image Diffusion Models via Prompt Interpolation
-        Currently running on {power_device}.
+    gr.Markdown(
         """
-        )
-
-        input_image = gr.Image(type="pil", label="Image to Augment")
-
-        with gr.Row():
-            prompt1 = gr.Text(
-                label="Prompt 1",
-                show_label=True,
-                max_lines=1,
-                placeholder="Enter your first prompt",
-                container=False,
-            )
-        with gr.Row():
-            prompt2 = gr.Text(
-                label="Prompt 2",
-                show_label=True,
-                max_lines=1,
-                placeholder="Enter your second prompt",
-                container=False,
-            )
-        with gr.Row():
-            gr.Examples(
-                examples=examples1, inputs=[prompt1], label="Example for Prompt 1"
-            )
-            gr.Examples(
-                examples=examples2, inputs=[prompt2], label="Example for Prompt 2"
-            )
-
-        with gr.Row():
-            num_interpolation_steps = gr.Slider(
-                label="Total interpolation steps",
-                minimum=2,
-                maximum=32,
-                step=2,
-                value=16,
-            )
-            interpolation_step = gr.Slider(
-                label="Specific Interpolation Step",
-                minimum=1,
-                maximum=8,
-                step=1,
-                value=8,
-            )
-            num_interpolation_steps.change(
-                fn=update_steps,
-                inputs=[num_interpolation_steps, interpolation_step],
-                outputs=[interpolation_step],
-            )
-        run_button = gr.Button("Run", scale=0)
-
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=True):
-
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=False,
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
+        # Data Augmentation with Image-to-Image Diffusion Models via Prompt Interpolation.
+        """
+    )
+    with gr.Row():
+        with gr.Column():
 
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+            input_image = gr.Image(type="pil", label="Image to Augment")
 
-            gr.Markdown("Negative Prompt: ")
             with gr.Row():
-                negative_prompt = gr.Text(
-                    label="Negative Prompt",
+                prompt1 = gr.Text(
+                    label="Prompt 1",
                     show_label=True,
                     max_lines=1,
-                    value="blurry image, disfigured, deformed, distorted, cartoon, drawings",
+                    placeholder="Enter your first prompt",
                     container=False,
                 )
             with gr.Row():
-
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=512,
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=512,
+                prompt2 = gr.Text(
+                    label="Prompt 2",
+                    show_label=True,
+                    max_lines=1,
+                    placeholder="Enter your second prompt",
+                    container=False,
                 )
-
             with gr.Row():
-
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=8.0,
+                gr.Examples(
+                    examples=examples1, inputs=[prompt1], label="Example for Prompt 1"
+                )
+                gr.Examples(
+                    examples=examples2, inputs=[prompt2], label="Example for Prompt 2"
                 )
 
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
+            with gr.Row():
+                interpolation_step = gr.Slider(
+                    label="Specific Interpolation Step",
                     minimum=1,
-                    maximum=80,
+                    maximum=8,
                     step=1,
-                    value=25,
+                    value=8,
                 )
-            with gr.Row():
-                sample_mid_interpolation = gr.Slider(
-                    label="Number of sampling steps in the middle of interpolation",
+                num_interpolation_steps = gr.Slider(
+                    label="Total interpolation steps",
                     minimum=2,
-                    maximum=80,
+                    maximum=32,
                     step=2,
                     value=16,
                 )
-            with gr.Row():
-                remove_n_middle = gr.Slider(
-                    label="Number of middle steps to remove from interpolation",
+                num_interpolation_steps.change(
+                    fn=update_steps,
+                    inputs=[num_interpolation_steps, interpolation_step],
+                    outputs=[interpolation_step],
+                )
+            run_button = gr.Button("Run", scale=0)
+            with gr.Accordion("Advanced Settings", open=True):
+                negative_prompt = gr.Text(
+                    label="Negative prompt",
+                    max_lines=1,
+                    placeholder="Enter a negative prompt",
+                    visible=False,
+                )
+                seed = gr.Slider(
+                    label="Seed",
                     minimum=0,
-                    maximum=80,
-                    step=2,
+                    maximum=MAX_SEED,
+                    step=1,
                     value=0,
                 )
-        gr.Markdown(
-            """
-            Metadata:
-            """
+                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                gr.Markdown("Negative Prompt: ")
+                with gr.Row():
+                    negative_prompt = gr.Text(
+                        label="Negative Prompt",
+                        show_label=True,
+                        max_lines=1,
+                        value="blurry image, disfigured, deformed, distorted, cartoon, drawings",
+                        container=False,
+                    )
+                with gr.Row():
+                    width = gr.Slider(
+                        label="Width",
+                        minimum=256,
+                        maximum=MAX_IMAGE_SIZE,
+                        step=32,
+                        value=512,
+                    )
+                    height = gr.Slider(
+                        label="Height",
+                        minimum=256,
+                        maximum=MAX_IMAGE_SIZE,
+                        step=32,
+                        value=512,
+                    )
+                with gr.Row():
+                    guidance_scale = gr.Slider(
+                        label="Guidance scale",
+                        minimum=0.0,
+                        maximum=10.0,
+                        step=0.1,
+                        value=8.0,
+                    )
+                    num_inference_steps = gr.Slider(
+                        label="Number of inference steps",
+                        minimum=1,
+                        maximum=80,
+                        step=1,
+                        value=25,
+                    )
+                with gr.Row():
+                    sample_mid_interpolation = gr.Slider(
+                        label="Number of sampling steps in the middle of interpolation",
+                        minimum=2,
+                        maximum=80,
+                        step=2,
+                        value=16,
+                    )
+                num_interpolation_steps.change(
+                    fn=update_sampling_steps,
+                    inputs=[num_interpolation_steps, sample_mid_interpolation],
+                    outputs=[sample_mid_interpolation],
+                )
+                with gr.Row():
+                    remove_n_middle = gr.Slider(
+                        label="Number of middle steps to remove from interpolation",
+                        minimum=0,
+                        maximum=80,
+                        step=2,
+                        value=0,
+                    )
+        with gr.Column():
+            result = gr.Image(label="Result", show_label=False)
+
+            gr.Markdown(
+                """
+                Metadata:
+                """
+            )
+            with gr.Row():
+                show_seed = gr.Label(label="Seed:", value="Randomized seed")
+                ssim_score = gr.Label(
+                    label="SSIM Score:", value="Generate to see score"
+                )
+                cos_sim = gr.Label(label="CLIP Score:", value="Generate to see score")
+            gr.Markdown(
+                f"""
+                Currently running on {power_device}.
+                """
+            )
+
+        run_button.click(
+            fn=infer,
+            inputs=[
+                input_image,
+                prompt1,
+                prompt2,
+                negative_prompt,
+                seed,
+                randomize_seed,
+                width,
+                height,
+                guidance_scale,
+                interpolation_step,
+                num_inference_steps,
+                num_interpolation_steps,
+                sample_mid_interpolation,
+                remove_n_middle,
+            ],
+            outputs=[result, show_seed, ssim_score, cos_sim],
         )
-        with gr.Row():
-            show_seed = gr.Label(label="Seed:", value="Randomized seed")
-            ssim_score = gr.Label(label="SSIM Score:", value="Generate to see score")
-            cos_sim = gr.Label(label="CLIP Score:", value="Generate to see score")
-    run_button.click(
-        fn=infer,
-        inputs=[
-            input_image,
-            prompt1,
-            prompt2,
-            negative_prompt,
-            seed,
-            randomize_seed,
-            width,
-            height,
-            guidance_scale,
-            interpolation_step,
-            num_inference_steps,
-            num_interpolation_steps,
-            sample_mid_interpolation,
-            remove_n_middle,
-        ],
-        outputs=[result, show_seed, ssim_score, cos_sim],
-    )
 
-demo.queue().launch()
+demo.queue().launch(show_error=True)
 
 """
     input_image,
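For reference, the validation block added at the top of infer() boils down to three constraints on the step sliders. A minimal standalone sketch of the same checks (the helper name below is ours for illustration, not part of app.py):

def check_interpolation_params(num_interpolation_steps, sample_mid_interpolation, remove_n_middle):
    # All three counts must be even, mirroring the per-parameter asserts in infer().
    if any(v % 2 for v in (num_interpolation_steps, sample_mid_interpolation, remove_n_middle)):
        raise ValueError("all step counts must be even numbers")
    # At least 2 steps overall, and no more mid-interpolation samples than total steps.
    if not (num_interpolation_steps >= sample_mid_interpolation >= 2):
        raise ValueError("need num_interpolation_steps >= sample_mid_interpolation >= 2")
    # Removing middle steps must still leave at least 2 sampled embeddings.
    if sample_mid_interpolation - remove_n_middle < 2:
        raise ValueError("sample_mid_interpolation must be at least remove_n_middle + 2")

# The demo defaults (16, 16, 0) pass; (16, 8, 8) would fail the last check.
check_interpolation_params(16, 16, 0)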