zhiweili committed
Commit 336094b
Parent: 5e90935

test app_ddim

Files changed (3):
  1. app.py +1 -1
  2. app_ddim.py +260 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -1,6 +1,6 @@
  import gradio as gr
 
- from app_haircolor_inpaint import create_demo as create_demo_haircolor
+ from app_ddim import create_demo as create_demo_haircolor
 
  with gr.Blocks(css="style.css") as demo:
      with gr.Tabs():
app_ddim.py ADDED
@@ -0,0 +1,260 @@
+ import spaces
+ import gradio as gr
+ import time
+ import torch
+ import numpy as np
+
+ from tqdm.auto import tqdm
+ from torchvision import transforms as tfms
+ from PIL import Image
+ from segment_utils import (
+     segment_image,
+     restore_result,
+ )
+ from diffusers import (
+     StableDiffusionPipeline,
+     DDIMScheduler,
+ )
+
+ BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ DEFAULT_INPUT_PROMPT = "a woman"
+ DEFAULT_EDIT_PROMPT = "a woman with linen-blonde-hair"
+
+ DEFAULT_CATEGORY = "hair"
+
+ basepipeline = StableDiffusionPipeline.from_pretrained(
+     BASE_MODEL,
+     torch_dtype=torch.float16,
+     use_safetensors=True,
+ )
+
+ basepipeline.scheduler = DDIMScheduler.from_config(basepipeline.scheduler.config)
+
+ basepipeline = basepipeline.to(DEVICE)
+
+ basepipeline.enable_model_cpu_offload()
+
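+ # Note: the scheduler is swapped to DDIMScheduler above because invert() and
+ # sample() below read scheduler.alphas_cumprod and implement the deterministic
+ # DDIM update by hand, assuming the scheduler's evenly spaced timesteps.
+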
+ @spaces.GPU(duration=30)
+ def image_to_image(
+     input_image: Image.Image,
+     input_image_prompt: str,
+     edit_prompt: str,
+     num_steps: int,
+     start_step: int,
+     guidance_scale: float,
+ ):
+     run_task_time = 0
+     time_cost_str = ''
+     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+
+     with torch.no_grad():
+         # cast the input to the VAE dtype (the pipeline is loaded in fp16)
+         latent = basepipeline.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(DEVICE, dtype=basepipeline.vae.dtype) * 2 - 1)
+     l = 0.18215 * latent.latent_dist.sample()
+     inverted_latents = invert(l, input_image_prompt, num_inference_steps=num_steps)
+     generated_image = sample(
+         edit_prompt,
+         start_latents=inverted_latents[-(start_step + 1)][None],
+         start_step=start_step,
+         num_inference_steps=num_steps,
+         guidance_scale=guidance_scale,
+     )[0]
+
+     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+
+     return generated_image, time_cost_str
+
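+ # Note: 0.18215 is the Stable Diffusion v1 VAE scaling factor (also exposed as
+ # basepipeline.vae.config.scaling_factor). The input is mapped to [-1, 1] before
+ # encoding, and the cropped region from segment_image is assumed to be square at
+ # generate_size (512 by default), which yields 64x64 latents.
+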
+ def make_inpaint_condition(image, image_mask):
+     image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+     image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
+
+     assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
+     image[image_mask > 0.5] = -1.0  # set as masked pixel
+     image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+     image = torch.from_numpy(image)
+     return image
+
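+ # Note: make_inpaint_condition is not called anywhere in this file; it appears to
+ # be carried over from the ControlNet inpaint variant of this demo.
+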
+ ## Inversion
+ @torch.no_grad()
+ def invert(
+     start_latents,
+     prompt,
+     guidance_scale=3.5,
+     num_inference_steps=80,
+     num_images_per_prompt=1,
+     do_classifier_free_guidance=True,
+     negative_prompt="",
+     device=DEVICE,
+ ):
+
+     # Encode prompt
+     text_embeddings = basepipeline._encode_prompt(
+         prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+     )
+
+     # Latents are now the specified start latents
+     latents = start_latents.clone()
+
+     # We'll keep a list of the inverted latents as the process goes on
+     intermediate_latents = []
+
+     # Set num inference steps
+     basepipeline.scheduler.set_timesteps(num_inference_steps, device=device)
+
+     # Reversed timesteps
+     timesteps = reversed(basepipeline.scheduler.timesteps)
+
+     for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):
+
+         # We'll skip the final iteration
+         if i >= num_inference_steps - 1:
+             continue
+
+         t = timesteps[i]
+
+         # Expand the latents if we are doing classifier free guidance
+         latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+         latent_model_input = basepipeline.scheduler.scale_model_input(latent_model_input, t)
+
+         # Predict the noise residual
+         noise_pred = basepipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+         # Perform guidance
+         if do_classifier_free_guidance:
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
+         next_t = t  # min(999, t.item() + (1000 // num_inference_steps))  # t+1
+         alpha_t = basepipeline.scheduler.alphas_cumprod[current_t]
+         alpha_t_next = basepipeline.scheduler.alphas_cumprod[next_t]
+
+         # Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents))
+         latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
+             1 - alpha_t_next
+         ).sqrt() * noise_pred
+
+         # Store
+         intermediate_latents.append(latents)
+
+     return torch.cat(intermediate_latents)
+
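+ # Notes on invert(): with a_t = alphas_cumprod[t], the deterministic DDIM step is
+ #     x_prev = sqrt(a_prev) * x0_hat + sqrt(1 - a_prev) * eps,  where
+ #     x0_hat = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t).
+ # The update above applies the same formula in the opposite direction (a_next in
+ # place of a_prev), reusing the noise predicted at the current latent, so each
+ # iteration adds noise instead of removing it. intermediate_latents is ordered
+ # from least to most noised, which is why image_to_image indexes it with
+ # -(start_step + 1).
+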
+ # Sample function (regular DDIM)
+ @torch.no_grad()
+ def sample(
+     prompt,
+     start_step=0,
+     start_latents=None,
+     guidance_scale=3.5,
+     num_inference_steps=30,
+     num_images_per_prompt=1,
+     do_classifier_free_guidance=True,
+     negative_prompt="",
+     device=DEVICE,
+ ):
+
+     # Encode prompt
+     text_embeddings = basepipeline._encode_prompt(
+         prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+     )
+
+     # Set num inference steps
+     basepipeline.scheduler.set_timesteps(num_inference_steps, device=device)
+
+     # Create a random starting point if we don't have one already
+     if start_latents is None:
+         start_latents = torch.randn(1, 4, 64, 64, device=device)
+         start_latents *= basepipeline.scheduler.init_noise_sigma
+
+     latents = start_latents.clone()
+
+     for i in tqdm(range(start_step, num_inference_steps)):
+
+         t = basepipeline.scheduler.timesteps[i]
+
+         # Expand the latents if we are doing classifier free guidance
+         latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+         latent_model_input = basepipeline.scheduler.scale_model_input(latent_model_input, t)
+
+         # Predict the noise residual
+         noise_pred = basepipeline.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+         # Perform guidance
+         if do_classifier_free_guidance:
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         # Normally we'd rely on the scheduler to handle the update step:
+         # latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
+
+         # Instead, let's do it ourselves:
+         prev_t = max(1, t.item() - (1000 // num_inference_steps))  # t-1
+         alpha_t = basepipeline.scheduler.alphas_cumprod[t.item()]
+         alpha_t_prev = basepipeline.scheduler.alphas_cumprod[prev_t]
+         predicted_x0 = (latents - (1 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
+         direction_pointing_to_xt = (1 - alpha_t_prev).sqrt() * noise_pred
+         latents = alpha_t_prev.sqrt() * predicted_x0 + direction_pointing_to_xt
+
+     # Post-processing
+     images = basepipeline.decode_latents(latents)
+     images = basepipeline.numpy_to_pil(images)
+
+     return images
+
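+ # Notes on sample(): the manual update is the standard eta=0 DDIM step: predict
+ # x0 from the current latent and noise estimate, then move to the previous noise
+ # level. start_step controls how many denoising steps actually run; a larger
+ # start_step starts from a less-noised inverted latent and stays closer to the
+ # input image, while a smaller one gives the edit prompt more freedom.
+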
+ def get_time_cost(run_task_time, time_cost_str):
+     now_time = int(time.time() * 1000)
+     if run_task_time == 0:
+         time_cost_str = 'start'
+     else:
+         if time_cost_str != '':
+             time_cost_str += '-->'
+         time_cost_str += f'{now_time - run_task_time}'
+     run_task_time = now_time
+     return run_task_time, time_cost_str
+
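+ # The click chain in create_demo below runs in three stages: segment_image (from
+ # segment_utils) is expected to crop the region for the chosen category ("hair")
+ # and return it with a croper state, image_to_image edits that crop via DDIM
+ # inversion and resampling, and restore_result pastes the edited crop back into
+ # the original photo.
+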
+ def create_demo() -> gr.Blocks:
+     with gr.Blocks() as demo:
+         croper = gr.State()
+         with gr.Row():
+             with gr.Column():
+                 input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_INPUT_PROMPT)
+                 edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
+             with gr.Column():
+                 num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
+                 start_step = gr.Slider(minimum=0, maximum=100, value=15, step=1, label="Start Step")
+                 guidance_scale = gr.Slider(minimum=0, maximum=30, value=5, step=0.5, label="Guidance Scale")
+             with gr.Column():
+                 generate_size = gr.Number(label="Generate Size", value=512)
+                 with gr.Accordion("Advanced Options", open=False):
+                     mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
+                     mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
+                     category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
+                 g_btn = gr.Button("Edit Image")
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(label="Input Image", type="pil")
+             with gr.Column():
+                 restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
+             with gr.Column():
+                 origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
+                 generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
+                 generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
+
+         g_btn.click(
+             fn=segment_image,
+             inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
+             outputs=[origin_area_image, croper],
+         ).success(
+             fn=image_to_image,
+             inputs=[origin_area_image, input_image_prompt, edit_prompt, num_steps, start_step, guidance_scale],
+             outputs=[generated_image, generated_cost],
+         ).success(
+             fn=restore_result,
+             inputs=[croper, category, generated_image],
+             outputs=[restored_image],
+         )
+
+     return demo
requirements.txt CHANGED
@@ -8,4 +8,5 @@ mediapipe
  spaces
  sentencepiece
  controlnet_aux
- peft
+ peft
+ tqdm