zhiweili commited on
Commit
8be2702
1 Parent(s): 8f6203d

add step params

Browse files
Files changed (1) hide show
  1. app_diffedit.py +20 -5
app_diffedit.py CHANGED
@@ -29,6 +29,9 @@ def image_to_image(
29
  input_image: Image,
30
  source_prompt: str,
31
  target_prompt: str,
 
 
 
32
  ):
33
  run_task_time = 0
34
  time_cost_str = ''
@@ -39,17 +42,29 @@ def image_to_image(
39
  image=input_image,
40
  source_prompt=source_prompt,
41
  target_prompt=target_prompt,
 
 
42
  )
43
 
44
- inv_latents = basepipeline.invert(prompt=source_prompt, image=input_image).latents
 
 
 
 
 
 
 
 
45
 
46
  output_image = basepipeline(
47
  prompt=target_prompt,
48
  mask_image=mask_image,
49
  image_latents=inv_latents,
50
  negative_prompt=source_prompt,
 
 
51
  ).images[0]
52
- mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L")
53
 
54
 
55
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
@@ -74,10 +89,10 @@ def create_demo() -> gr.Blocks:
74
  input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
75
  edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
76
  with gr.Column():
77
- num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
78
  start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
79
  with gr.Column():
80
- seed = gr.Number(label="Seed", value=8)
81
  g_btn = gr.Button("Edit Image")
82
 
83
  with gr.Row():
@@ -91,7 +106,7 @@ def create_demo() -> gr.Blocks:
91
 
92
  g_btn.click(
93
  fn=image_to_image,
94
- inputs=[input_image, input_image_prompt, edit_prompt],
95
  outputs=[output_image, mask_image, generated_cost],
96
  )
97
 
 
29
  input_image: Image,
30
  source_prompt: str,
31
  target_prompt: str,
32
+ num_inference_steps: int,
33
+ start_step: int,
34
+ guidance_scale: float,
35
  ):
36
  run_task_time = 0
37
  time_cost_str = ''
 
42
  image=input_image,
43
  source_prompt=source_prompt,
44
  target_prompt=target_prompt,
45
+ num_inference_steps = num_inference_steps,
46
+ guidance_scale=guidance_scale,
47
  )
48
 
49
+ inv_latents = basepipeline.invert(
50
+ prompt=source_prompt,
51
+ image=input_image,
52
+ num_inference_steps = num_inference_steps,
53
+ guidance_scale=guidance_scale,
54
+ ).latents
55
+
56
+ # get inverse latents by start step
57
+ # inv_latents = inv_latents[-(start_step + 1)][None]
58
 
59
  output_image = basepipeline(
60
  prompt=target_prompt,
61
  mask_image=mask_image,
62
  image_latents=inv_latents,
63
  negative_prompt=source_prompt,
64
+ num_inference_steps = num_inference_steps,
65
+ guidance_scale=guidance_scale,
66
  ).images[0]
67
+ mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize(input_image.size, Image.LANCZOS)
68
 
69
 
70
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
89
  input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
90
  edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
91
  with gr.Column():
92
+ num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Inference Steps")
93
  start_step = gr.Slider(minimum=1, maximum=100, value=15, step=1, label="Start Step")
94
  with gr.Column():
95
+ guidance_scale = gr.Slider(minimum=0, maximum=20, value=7.5, step=0.5, label="Guidance Scale")
96
  g_btn = gr.Button("Edit Image")
97
 
98
  with gr.Row():
 
106
 
107
  g_btn.click(
108
  fn=image_to_image,
109
+ inputs=[input_image, input_image_prompt, edit_prompt, num_inference_steps, start_step, guidance_scale],
110
  outputs=[output_image, mask_image, generated_cost],
111
  )
112