michaelj committed on
Commit
2de5b8f
1 Parent(s): 9e7d17e
Files changed (1)
  1. backend/utils_sd.py +1419 -0
backend/utils_sd.py ADDED
@@ -0,0 +1,1419 @@
1
+
3
+ import numpy as np
4
+ import cv2
5
+ import torch
6
+ import random
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ import copy
9
+ from typing import Optional, Union, Tuple, List, Callable, Dict, Any
10
+ from tqdm.notebook import tqdm
11
+ from diffusers.utils import BaseOutput, logging
12
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
13
+ from diffusers.models.unet_2d_blocks import (
14
+ CrossAttnDownBlock2D,
15
+ CrossAttnUpBlock2D,
16
+ DownBlock2D,
17
+ UNetMidBlock2DCrossAttn,
18
+ UpBlock2D,
19
+ get_down_block,
20
+ get_up_block,
21
+ )
22
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput, logger
23
+ from copy import deepcopy
24
+ import json
25
+
26
+ import inspect
27
+ import os
28
+ import warnings
29
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
30
+
31
+ import numpy as np
32
+ import PIL.Image
33
+ import torch
34
+ import torch.nn.functional as F
35
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
36
+
37
+ from diffusers.image_processor import VaeImageProcessor
38
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
39
+ from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
40
+ from diffusers.schedulers import KarrasDiffusionSchedulers
41
+ from diffusers.utils.torch_utils import is_compiled_module
42
+
43
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
44
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
45
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
46
+ from tqdm import tqdm
47
+ from controlnet_aux import HEDdetector, OpenposeDetector
48
+ import time
49
+
50
+ def seed_everything(seed):
51
+ torch.manual_seed(seed)
52
+ torch.cuda.manual_seed(seed)
53
+ random.seed(seed)
54
+ np.random.seed(seed)
55
+
56
+ def get_promptls(prompt_path):
57
+ with open(prompt_path) as f:
58
+ prompt_ls = json.load(f)
59
+ prompt_ls = [prompt['caption'].replace('/','_') for prompt in prompt_ls]
60
+ return prompt_ls
61
+
62
+ def load_512(image_path, left=0, right=0, top=0, bottom=0):
63
+ # print(image_path)
64
+ if type(image_path) is str:
65
+ image = np.array(Image.open(image_path))
66
+ if image.ndim == 3 and image.shape[2] > 3:
67
+ image = image[:,:,:3]
68
+ elif image.ndim == 2:
69
+ image = image.reshape(image.shape[0], image.shape[1],1).astype('uint8')
70
+ else:
71
+ image = image_path
72
+ h, w, c = image.shape
73
+ left = min(left, w-1)
74
+ right = min(right, w - left - 1)
75
+ top = min(top, h - 1)
76
+ bottom = min(bottom, h - top - 1)
77
+ image = image[top:h-bottom, left:w-right]
78
+ h, w, c = image.shape
79
+ if h < w:
80
+ offset = (w - h) // 2
81
+ image = image[:, offset:offset + h]
82
+ elif w < h:
83
+ offset = (h - w) // 2
84
+ image = image[offset:offset + w]
85
+ image = np.array(Image.fromarray(image).resize((512, 512)))
86
+ return image
87
+
88
+ def get_canny(image_path):
89
+ image = load_512(
90
+ image_path
91
+ )
92
+ image = np.array(image)
93
+
94
+ # get canny image
95
+ image = cv2.Canny(image, 100, 200)
96
+ image = image[:, :, None]
97
+ image = np.concatenate([image, image, image], axis=2)
98
+ canny_image = Image.fromarray(image)
99
+ return canny_image
100
+
101
+
102
+ def get_scribble(image_path, hed):
103
+ image = load_512(
104
+ image_path
105
+ )
106
+ image = hed(image, scribble=True)
107
+
108
+ return image
109
+
110
+ def get_cocoimages(prompt_path):
111
+ data_ls = []
112
+ with open(prompt_path) as f:
113
+ prompt_ls = json.load(f)
114
+ img_path = 'COCO2017-val/val2017'
115
+ for prompt in tqdm(prompt_ls):
116
+ caption = prompt['caption'].replace('/','_')
117
+ image_id = str(prompt['image_id'])
118
+ image_id = (12-len(image_id))*'0' + image_id+'.jpg'
119
+ image_path = os.path.join(img_path, image_id)
120
+ try:
121
+ image = get_canny(image_path)
122
+ except Exception:
123
+ continue
124
+ curr_data = {'image':image, 'prompt':caption}
125
+ data_ls.append(curr_data)
126
+ return data_ls
127
+
128
+ def get_cocoimages2(prompt_path):
129
+ """scribble condition
130
+ """
131
+ data_ls = []
132
+ with open(prompt_path) as f:
133
+ prompt_ls = json.load(f)
134
+ img_path = 'COCO2017-val/val2017'
135
+ hed = HEDdetector.from_pretrained('ControlNet/detector_weights/annotator', filename='network-bsds500.pth')
136
+ for prompt in tqdm(prompt_ls):
137
+ caption = prompt['caption'].replace('/','_')
138
+ image_id = str(prompt['image_id'])
139
+ image_id = (12-len(image_id))*'0' + image_id+'.jpg'
140
+ image_path = os.path.join(img_path, image_id)
141
+ try:
142
+ image = get_scribble(image_path,hed)
143
+ except Exception:
144
+ continue
145
+ curr_data = {'image':image, 'prompt':caption}
146
+ data_ls.append(curr_data)
147
+ return data_ls
148
+
149
+ def warpped_feature(sample, step):
150
+ """
151
+ sample: batch_size*dim*h*w, uncond: 0 - batch_size//2, cond: batch_size//2 - batch_size
152
+ step: timestep span
153
+ """
154
+ bs, dim, h, w = sample.shape
155
+ uncond_fea, cond_fea = sample.chunk(2)
156
+ uncond_fea = uncond_fea.repeat(step,1,1,1) # (step * bs//2) * dim * h *w
157
+ cond_fea = cond_fea.repeat(step,1,1,1) # (step * bs//2) * dim * h *w
158
+ return torch.cat([uncond_fea, cond_fea])
159
+
160
+ def warpped_skip_feature(block_samples, step):
161
+ down_block_res_samples = []
162
+ for sample in block_samples:
163
+ sample_expand = warpped_feature(sample, step)
164
+ down_block_res_samples.append(sample_expand)
165
+ return tuple(down_block_res_samples)
166
+
167
+ def warpped_text_emb(text_emb, step):
168
+ """
169
+ text_emb: batch_size*77*768, uncond: 0 - batch_size//2, cond: batch_size//2 - batch_size
170
+ step: timestep span
171
+ """
172
+ bs, token_len, dim = text_emb.shape
173
+ uncond_fea, cond_fea = text_emb.chunk(2)
174
+ uncond_fea = uncond_fea.repeat(step,1,1) # (step * bs//2) * 77 *768
175
+ cond_fea = cond_fea.repeat(step,1,1) # (step * bs//2) * 77 * 768
176
+ return torch.cat([uncond_fea, cond_fea]) # (step*bs) * 77 *768
177
+
178
+ def warpped_timestep(timesteps, bs):
179
+ """
180
+ timesteps: list, such as [981, 961, 941]
181
+ """
182
+ semi_bs = bs//2
183
+ ts = []
184
+ for timestep in timesteps:
185
+ timestep = timestep[None]
186
+ texp = timestep.expand(semi_bs)
187
+ ts.append(texp)
188
+ timesteps = torch.cat(ts)
189
+ return timesteps.repeat(2,1).reshape(-1)
190
+
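+ # Shape sketch for the warpped_* helpers above (example values, assuming one prompt
+ # with classifier-free guidance, i.e. batch layout [uncond | cond]): for a sample of
+ # shape (2, 4, 64, 64) and step=3, warpped_feature returns (6, 4, 64, 64) laid out as
+ # three copies of the uncond half followed by three copies of the cond half, and
+ # warpped_timestep returns the matching order [t1, t2, t3, t1, t2, t3], so a single
+ # batched decoder pass serves three consecutive timesteps.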
191
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
192
+ """
193
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
194
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
195
+ """
196
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
197
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
198
+ # rescale the results from guidance (fixes overexposure)
199
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
200
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
201
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
202
+ return noise_cfg
203
+
204
+ def register_normal_pipeline(pipe):
205
+ def new_call(self):
206
+ @torch.no_grad()
207
+ def call(
208
+ prompt: Union[str, List[str]] = None,
209
+ height: Optional[int] = None,
210
+ width: Optional[int] = None,
211
+ num_inference_steps: int = 50,
212
+ guidance_scale: float = 7.5,
213
+ negative_prompt: Optional[Union[str, List[str]]] = None,
214
+ num_images_per_prompt: Optional[int] = 1,
215
+ eta: float = 0.0,
216
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
217
+ latents: Optional[torch.FloatTensor] = None,
218
+ prompt_embeds: Optional[torch.FloatTensor] = None,
219
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
220
+ output_type: Optional[str] = "pil",
221
+ return_dict: bool = True,
222
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
223
+ guidance_rescale: float = 0.0,
224
+ clip_skip: Optional[int] = None,
225
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
226
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
227
+ **kwargs,
228
+ ):
229
+
230
+ callback = kwargs.pop("callback", None)
231
+ callback_steps = kwargs.pop("callback_steps", None)
232
+
233
+
234
+ # 0. Default height and width to unet
235
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
236
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
237
+ # to deal with lora scaling and other possible forward hooks
238
+
239
+ # 1. Check inputs. Raise error if not correct
240
+ self.check_inputs(
241
+ prompt,
242
+ height,
243
+ width,
244
+ callback_steps,
245
+ negative_prompt,
246
+ prompt_embeds,
247
+ negative_prompt_embeds,
248
+ callback_on_step_end_tensor_inputs,
249
+ )
250
+
251
+ self._guidance_scale = guidance_scale
252
+ self._guidance_rescale = guidance_rescale
253
+ self._clip_skip = clip_skip
254
+ self._cross_attention_kwargs = cross_attention_kwargs
255
+
256
+ # 2. Define call parameters
257
+ if prompt is not None and isinstance(prompt, str):
258
+ batch_size = 1
259
+ elif prompt is not None and isinstance(prompt, list):
260
+ batch_size = len(prompt)
261
+ else:
262
+ batch_size = prompt_embeds.shape[0]
263
+
264
+ device = self._execution_device
265
+
266
+ # 3. Encode input prompt
267
+ lora_scale = (
268
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
269
+ )
270
+
271
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
272
+ prompt,
273
+ device,
274
+ num_images_per_prompt,
275
+ self.do_classifier_free_guidance,
276
+ negative_prompt,
277
+ prompt_embeds=prompt_embeds,
278
+ negative_prompt_embeds=negative_prompt_embeds,
279
+ lora_scale=lora_scale,
280
+ clip_skip=self.clip_skip,
281
+ )
282
+ # For classifier free guidance, we need to do two forward passes.
283
+ # Here we concatenate the unconditional and text embeddings into a single batch
284
+ # to avoid doing two forward passes
285
+ if self.do_classifier_free_guidance:
286
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
287
+
288
+ # 4. Prepare timesteps
289
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
290
+ timesteps = self.scheduler.timesteps
291
+
292
+ # 5. Prepare latent variables
293
+ num_channels_latents = self.unet.config.in_channels
294
+ latents = self.prepare_latents(
295
+ batch_size * num_images_per_prompt,
296
+ num_channels_latents,
297
+ height,
298
+ width,
299
+ prompt_embeds.dtype,
300
+ device,
301
+ generator,
302
+ latents,
303
+ )
304
+
305
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
306
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
307
+
308
+ # 6.5 Optionally get Guidance Scale Embedding
309
+ timestep_cond = None
310
+ if self.unet.config.time_cond_proj_dim is not None:
311
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
312
+ timestep_cond = self.get_guidance_scale_embedding(
313
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
314
+ ).to(device=device, dtype=latents.dtype)
315
+
316
+ # 7. Denoising loop
317
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
318
+ self._num_timesteps = len(timesteps)
319
+ init_latents = latents.detach().clone()
320
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
321
+ for i, t in enumerate(timesteps):
322
+ if t/1000 < 0.5:
323
+ latents = latents + 0.003*init_latents
324
+ setattr(self.unet, 'order', i)
325
+ # expand the latents if we are doing classifier free guidance
326
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
327
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
328
+
329
+ # predict the noise residual
330
+ noise_pred = self.unet(
331
+ latent_model_input,
332
+ t,
333
+ encoder_hidden_states=prompt_embeds,
334
+ timestep_cond=timestep_cond,
335
+ cross_attention_kwargs=self.cross_attention_kwargs,
336
+ return_dict=False,
337
+ )[0]
338
+
339
+ # perform guidance
340
+ if self.do_classifier_free_guidance:
341
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
342
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
343
+
344
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
345
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
346
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
347
+
348
+ # compute the previous noisy sample x_t -> x_t-1
349
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
350
+
351
+ if callback_on_step_end is not None:
352
+ callback_kwargs = {}
353
+ for k in callback_on_step_end_tensor_inputs:
354
+ callback_kwargs[k] = locals()[k]
355
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
356
+
357
+ latents = callback_outputs.pop("latents", latents)
358
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
359
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
360
+
361
+ # call the callback, if provided
362
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
363
+ progress_bar.update()
364
+ if callback is not None and i % callback_steps == 0:
365
+ step_idx = i // getattr(self.scheduler, "order", 1)
366
+ callback(step_idx, t, latents)
367
+
368
+ if not output_type == "latent":
369
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
370
+ 0
371
+ ]
372
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
373
+ else:
374
+ image = latents
375
+ has_nsfw_concept = None
376
+
377
+ if has_nsfw_concept is None:
378
+ do_denormalize = [True] * image.shape[0]
379
+ else:
380
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
381
+
382
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
383
+
384
+ # Offload all models
385
+ self.maybe_free_model_hooks()
386
+
387
+ if not return_dict:
388
+ return (image, has_nsfw_concept)
389
+
390
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
391
+ return call
392
+ pipe.call = new_call(pipe)
393
+
394
+
395
+ def register_parallel_pipeline(pipe):
396
+ def new_call(self):
397
+ @torch.no_grad()
398
+ def call(
399
+ prompt: Union[str, List[str]] = None,
400
+ height: Optional[int] = None,
401
+ width: Optional[int] = None,
402
+ num_inference_steps: int = 50,
403
+ guidance_scale: float = 7.5,
404
+ negative_prompt: Optional[Union[str, List[str]]] = None,
405
+ num_images_per_prompt: Optional[int] = 1,
406
+ eta: float = 0.0,
407
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
408
+ latents: Optional[torch.FloatTensor] = None,
409
+ prompt_embeds: Optional[torch.FloatTensor] = None,
410
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
411
+ output_type: Optional[str] = "pil",
412
+ return_dict: bool = True,
413
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
414
+ guidance_rescale: float = 0.0,
415
+ clip_skip: Optional[int] = None,
416
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
417
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
418
+ **kwargs,
419
+ ):
420
+
421
+ callback = kwargs.pop("callback", None)
422
+ callback_steps = kwargs.pop("callback_steps", None)
423
+
424
+
425
+ # 0. Default height and width to unet
426
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
427
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
428
+ # to deal with lora scaling and other possible forward hooks
429
+
430
+ # 1. Check inputs. Raise error if not correct
431
+ self.check_inputs(
432
+ prompt,
433
+ height,
434
+ width,
435
+ callback_steps,
436
+ negative_prompt,
437
+ prompt_embeds,
438
+ negative_prompt_embeds,
439
+ callback_on_step_end_tensor_inputs,
440
+ )
441
+
442
+ self._guidance_scale = guidance_scale
443
+ self._guidance_rescale = guidance_rescale
444
+ self._clip_skip = clip_skip
445
+ self._cross_attention_kwargs = cross_attention_kwargs
446
+
447
+ # 2. Define call parameters
448
+ if prompt is not None and isinstance(prompt, str):
449
+ batch_size = 1
450
+ elif prompt is not None and isinstance(prompt, list):
451
+ batch_size = len(prompt)
452
+ else:
453
+ batch_size = prompt_embeds.shape[0]
454
+
455
+ device = self._execution_device
456
+
457
+ # 3. Encode input prompt
458
+ lora_scale = (
459
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
460
+ )
461
+
462
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
463
+ prompt,
464
+ device,
465
+ num_images_per_prompt,
466
+ self.do_classifier_free_guidance,
467
+ negative_prompt,
468
+ prompt_embeds=prompt_embeds,
469
+ negative_prompt_embeds=negative_prompt_embeds,
470
+ lora_scale=lora_scale,
471
+ clip_skip=self.clip_skip,
472
+ )
473
+ # For classifier free guidance, we need to do two forward passes.
474
+ # Here we concatenate the unconditional and text embeddings into a single batch
475
+ # to avoid doing two forward passes
476
+ if self.do_classifier_free_guidance:
477
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
478
+
479
+ # 4. Prepare timesteps
480
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
481
+ timesteps = self.scheduler.timesteps
482
+
483
+ # 5. Prepare latent variables
484
+ num_channels_latents = self.unet.config.in_channels
485
+ latents = self.prepare_latents(
486
+ batch_size * num_images_per_prompt,
487
+ num_channels_latents,
488
+ height,
489
+ width,
490
+ prompt_embeds.dtype,
491
+ device,
492
+ generator,
493
+ latents,
494
+ )
495
+
496
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
497
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
498
+
499
+ # 6.5 Optionally get Guidance Scale Embedding
500
+ timestep_cond = None
501
+ if self.unet.config.time_cond_proj_dim is not None:
502
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
503
+ timestep_cond = self.get_guidance_scale_embedding(
504
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
505
+ ).to(device=device, dtype=latents.dtype)
506
+
507
+ # 7. Denoising loop
508
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
509
+ self._num_timesteps = len(timesteps)
510
+ init_latents = latents.detach().clone()
511
+ #-------------------------------------------------------
512
+ all_steps = len(self.scheduler.timesteps)
513
+ curr_span = 1
514
+ curr_step = 0
515
+
516
+ # st = time.time()
517
+ idx = 1
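+ # keytime lists the steps at which the UNet encoder is re-run and its features
+ # re-cached; each span between consecutive key steps is denoised with one batched
+ # decoder pass (register_faster_forward with mod='50ls' uses the same step list).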
518
+ keytime = [0,1,2,3,5,10,15,25,35]
519
+ keytime.append(all_steps)
520
+ while curr_step<all_steps:
521
+ refister_time(self.unet, curr_step)
522
+
523
+ merge_span = curr_span
524
+ if merge_span>0:
525
+ time_ls = []
526
+ for i in range(curr_step, curr_step+merge_span):
527
+ if i<all_steps:
528
+ time_ls.append(self.scheduler.timesteps[i])
529
+ else:
530
+ break
531
+
532
+ ##--------------------------------
533
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
534
+
535
+ # predict the noise residual
536
+ noise_pred = self.unet(
537
+ latent_model_input,
538
+ time_ls,
539
+ encoder_hidden_states=prompt_embeds,
540
+ timestep_cond=timestep_cond,
541
+ cross_attention_kwargs=self.cross_attention_kwargs,
542
+ return_dict=False,
543
+ )[0]
544
+
545
+ # perform guidance
546
+ if self.do_classifier_free_guidance:
547
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
548
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
549
+
550
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
551
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
552
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
553
+
554
+ # compute the previous noisy sample x_t -> x_t-1
555
+
556
+ step_span = len(time_ls)
557
+ bs = noise_pred.shape[0]
558
+ bs_perstep = bs//step_span
559
+
560
+ denoised_latent = latents
561
+ for i, timestep in enumerate(time_ls):
562
+ if timestep/1000 < 0.5:
563
+ denoised_latent = denoised_latent + 0.003*init_latents
564
+ curr_noise = noise_pred[i*bs_perstep:(i+1)*bs_perstep]
565
+ denoised_latent = self.scheduler.step(curr_noise, timestep, denoised_latent, **extra_step_kwargs, return_dict=False)[0]
566
+
567
+ latents = denoised_latent
568
+ ##----------------------------------------
569
+ curr_step += curr_span
570
+ idx += 1
571
+
572
+ if curr_step<all_steps:
573
+ curr_span = keytime[idx] - keytime[idx-1]
574
+
575
+
576
+ if not output_type == "latent":
577
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
578
+ 0
579
+ ]
580
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
581
+ else:
582
+ image = latents
583
+ has_nsfw_concept = None
584
+
585
+ if has_nsfw_concept is None:
586
+ do_denormalize = [True] * image.shape[0]
587
+ else:
588
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
589
+
590
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
591
+
592
+ # Offload all models
593
+ self.maybe_free_model_hooks()
594
+
595
+ if not return_dict:
596
+ return (image, has_nsfw_concept)
597
+
598
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
599
+ return call
600
+ pipe.call = new_call(pipe)
601
+
602
+ def register_faster_forward(model, mod = '50ls'):
603
+ def faster_forward(self):
604
+ def forward(
605
+ sample: torch.FloatTensor,
606
+ timestep: Union[torch.Tensor, float, int],
607
+ encoder_hidden_states: torch.Tensor,
608
+ class_labels: Optional[torch.Tensor] = None,
609
+ timestep_cond: Optional[torch.Tensor] = None,
610
+ attention_mask: Optional[torch.Tensor] = None,
611
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
612
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
613
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
614
+ return_dict: bool = True,
615
+ ) -> Union[UNet2DConditionOutput, Tuple]:
616
+ r"""
617
+ Args:
618
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
619
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
620
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
621
+ return_dict (`bool`, *optional*, defaults to `True`):
622
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
623
+ cross_attention_kwargs (`dict`, *optional*):
624
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
625
+ `self.processor` in
626
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
627
+
628
+ Returns:
629
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
630
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
631
+ returning a tuple, the first element is the sample tensor.
632
+ """
633
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
634
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
635
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
636
+ # on the fly if necessary.
637
+ default_overall_up_factor = 2**self.num_upsamplers
638
+
639
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
640
+ forward_upsample_size = False
641
+ upsample_size = None
642
+
643
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
644
+ logger.info("Forward upsample size to force interpolation output size.")
645
+ forward_upsample_size = True
646
+
647
+ # prepare attention_mask
648
+ if attention_mask is not None:
649
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
650
+ attention_mask = attention_mask.unsqueeze(1)
651
+
652
+ # 0. center input if necessary
653
+ if self.config.center_input_sample:
654
+ sample = 2 * sample - 1.0
655
+
656
+ # 1. time
657
+ if isinstance(timestep, list):
658
+ timesteps = timestep[0]
659
+ step = len(timestep)
660
+ else:
661
+ timesteps = timestep
662
+ step = 1
663
+ if not torch.is_tensor(timesteps) and (not isinstance(timesteps,list)):
664
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
665
+ # This would be a good case for the `match` statement (Python 3.10+)
666
+ is_mps = sample.device.type == "mps"
667
+ if isinstance(timestep, float):
668
+ dtype = torch.float32 if is_mps else torch.float64
669
+ else:
670
+ dtype = torch.int32 if is_mps else torch.int64
671
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
672
+ elif (not isinstance(timesteps,list)) and len(timesteps.shape) == 0:
673
+ timesteps = timesteps[None].to(sample.device)
674
+
675
+ if (not isinstance(timesteps,list)) and len(timesteps.shape) == 1:
676
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
677
+ timesteps = timesteps.expand(sample.shape[0])
678
+ elif isinstance(timesteps, list):
679
+ #timesteps list, such as [981,961,941]
680
+ timesteps = warpped_timestep(timesteps, sample.shape[0]).to(sample.device)
681
+ t_emb = self.time_proj(timesteps)
682
+
683
+ # `Timesteps` does not contain any weights and will always return f32 tensors
684
+ # but time_embedding might actually be running in fp16. so we need to cast here.
685
+ # there might be better ways to encapsulate this.
686
+ t_emb = t_emb.to(dtype=self.dtype)
687
+
688
+ emb = self.time_embedding(t_emb, timestep_cond)
689
+
690
+ if self.class_embedding is not None:
691
+ if class_labels is None:
692
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
693
+
694
+ if self.config.class_embed_type == "timestep":
695
+ class_labels = self.time_proj(class_labels)
696
+
697
+ # `Timesteps` does not contain any weights and will always return f32 tensors
698
+ # there might be better ways to encapsulate this.
699
+ class_labels = class_labels.to(dtype=sample.dtype)
700
+
701
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
702
+
703
+ if self.config.class_embeddings_concat:
704
+ emb = torch.cat([emb, class_emb], dim=-1)
705
+ else:
706
+ emb = emb + class_emb
707
+
708
+ if self.config.addition_embed_type == "text":
709
+ aug_emb = self.add_embedding(encoder_hidden_states)
710
+ emb = emb + aug_emb
711
+
712
+ if self.time_embed_act is not None:
713
+ emb = self.time_embed_act(emb)
714
+
715
+ if self.encoder_hid_proj is not None:
716
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
717
+
718
+ #===============
719
+ order = self.order #timestep, start by 0
720
+ #===============
721
+ ipow = int(np.sqrt(9 + 8*order))
722
+ cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35]
723
+ if isinstance(mod, int):
724
+ cond = order % mod == 0
725
+ elif mod == "pro":
726
+ cond = ipow * ipow == (9 + 8 * order)
727
+ elif mod == "50ls":
728
+ cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40]
729
+ elif mod == "50ls2":
730
+ cond = order in [0, 10, 11, 12, 15, 20, 25, 30,35,45] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40]
731
+ elif mod == "50ls3":
732
+ cond = order in [0, 20, 25, 30,35,45,46,47,48,49] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40]
733
+ elif mod == "50ls4":
734
+ cond = order in [0, 9, 13, 14, 15, 28, 29, 32, 36,45] #40 #[0,1,2,3, 5, 10, 15] #[0, 1, 2, 3, 5, 10, 15, 25, 35, 40]
735
+ elif mod == "100ls":
736
+ cond = order > 85 or order < 10 or order % 5 == 0
737
+ elif mod == "75ls":
738
+ cond = order > 65 or order < 10 or order % 5 == 0
739
+ elif mod == "s2":
740
+ cond = order < 20 or order > 40 or order % 2 == 0
741
+
742
+ if cond:
743
+ # print(order)  # debug: log which steps run the full encoder
744
+ # 2. pre-process
745
+ sample = self.conv_in(sample)
746
+
747
+ # 3. down
748
+ down_block_res_samples = (sample,)
749
+ for downsample_block in self.down_blocks:
750
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
751
+ sample, res_samples = downsample_block(
752
+ hidden_states=sample,
753
+ temb=emb,
754
+ encoder_hidden_states=encoder_hidden_states,
755
+ attention_mask=attention_mask,
756
+ cross_attention_kwargs=cross_attention_kwargs,
757
+ )
758
+ else:
759
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
760
+
761
+ down_block_res_samples += res_samples
762
+
763
+ if down_block_additional_residuals is not None:
764
+ new_down_block_res_samples = ()
765
+
766
+ for down_block_res_sample, down_block_additional_residual in zip(
767
+ down_block_res_samples, down_block_additional_residuals
768
+ ):
769
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
770
+ new_down_block_res_samples += (down_block_res_sample,)
771
+
772
+ down_block_res_samples = new_down_block_res_samples
773
+
774
+ # 4. mid
775
+ if self.mid_block is not None:
776
+ sample = self.mid_block(
777
+ sample,
778
+ emb,
779
+ encoder_hidden_states=encoder_hidden_states,
780
+ attention_mask=attention_mask,
781
+ cross_attention_kwargs=cross_attention_kwargs,
782
+ )
783
+
784
+ if mid_block_additional_residual is not None:
785
+ sample = sample + mid_block_additional_residual
786
+
787
+ #----------------------save feature-------------------------
788
+ # setattr(self, 'skip_feature', (tmp_sample.clone() for tmp_sample in down_block_res_samples))
789
+ setattr(self, 'skip_feature', deepcopy(down_block_res_samples))
790
+ setattr(self, 'toup_feature', sample.detach().clone())
791
+ #-----------------------save feature------------------------
792
+
793
+
794
+
795
+ #-------------------expand feature for parallel---------------
796
+ if isinstance(timestep, list):
797
+ #timesteps list, such as [981,961,941]
798
+ timesteps = warpped_timestep(timestep, sample.shape[0]).to(sample.device)
799
+ t_emb = self.time_proj(timesteps)
800
+
801
+ # `Timesteps` does not contain any weights and will always return f32 tensors
802
+ # but time_embedding might actually be running in fp16. so we need to cast here.
803
+ # there might be better ways to encapsulate this.
804
+ t_emb = t_emb.to(dtype=self.dtype)
805
+
806
+ emb = self.time_embedding(t_emb, timestep_cond)
807
+ # print(emb.shape)
808
+
809
+ # print(step, sample.shape)
810
+ down_block_res_samples = warpped_skip_feature(down_block_res_samples, step)
811
+ sample = warpped_feature(sample, step)
812
+ # print(step, sample.shape)
813
+
814
+ encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step)
815
+
816
+ # print(emb.shape)
817
+
818
+ #-------------------expand feature for parallel---------------
819
+
820
+ else:
821
+ down_block_res_samples = self.skip_feature
822
+ sample = self.toup_feature
823
+
824
+ #-------------------expand feature for parallel---------------
825
+ down_block_res_samples = warpped_skip_feature(down_block_res_samples, step)
826
+ sample = warpped_feature(sample, step)
827
+ encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step)
828
+ #-------------------expand feature for parallel---------------
829
+
830
+ # 5. up
831
+ for i, upsample_block in enumerate(self.up_blocks):
832
+ is_final_block = i == len(self.up_blocks) - 1
833
+
834
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
835
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
836
+
837
+ # if we have not reached the final block and need to forward the
838
+ # upsample size, we do it here
839
+ if not is_final_block and forward_upsample_size:
840
+ upsample_size = down_block_res_samples[-1].shape[2:]
841
+
842
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
843
+ sample = upsample_block(
844
+ hidden_states=sample,
845
+ temb=emb,
846
+ res_hidden_states_tuple=res_samples,
847
+ encoder_hidden_states=encoder_hidden_states,
848
+ cross_attention_kwargs=cross_attention_kwargs,
849
+ upsample_size=upsample_size,
850
+ attention_mask=attention_mask,
851
+ )
852
+ else:
853
+ sample = upsample_block(
854
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
855
+ )
856
+
857
+ # 6. post-process
858
+ if self.conv_norm_out:
859
+ sample = self.conv_norm_out(sample)
860
+ sample = self.conv_act(sample)
861
+ sample = self.conv_out(sample)
862
+
863
+ if not return_dict:
864
+ return (sample,)
865
+
866
+ return UNet2DConditionOutput(sample=sample)
867
+ return forward
868
+ if model.__class__.__name__ == 'UNet2DConditionModel':
869
+ model.forward = faster_forward(model)
870
+
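+ # Typical wiring (sketch, assuming a standard diffusers StableDiffusionPipeline
+ # object named `pipe`; see the __main__ block at the end of this file for a
+ # runnable version):
+ # register_parallel_pipeline(pipe)                # batched multi-timestep sampling loop
+ # register_faster_forward(pipe.unet, mod='50ls')  # encoder-feature caching at key steps
+ # out = pipe.call(prompt="a photo of a corgi", num_inference_steps=50)
+ # out.images[0].save("sample.png")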
871
+ def register_normal_forward(model):
872
+ def normal_forward(self):
873
+ def forward(
874
+ sample: torch.FloatTensor,
875
+ timestep: Union[torch.Tensor, float, int],
876
+ encoder_hidden_states: torch.Tensor,
877
+ class_labels: Optional[torch.Tensor] = None,
878
+ timestep_cond: Optional[torch.Tensor] = None,
879
+ attention_mask: Optional[torch.Tensor] = None,
880
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
881
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
882
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
883
+ return_dict: bool = True,
884
+ ) -> Union[UNet2DConditionOutput, Tuple]:
885
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
886
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
887
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
888
+ # on the fly if necessary.
889
+ default_overall_up_factor = 2**self.num_upsamplers
890
+
891
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
892
+ forward_upsample_size = False
893
+ upsample_size = None
894
+ #---------------------
895
+ # import os
896
+ # os.makedirs(f'{timestep.item()}_step', exist_ok=True)
897
+ #---------------------
898
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
899
+ logger.info("Forward upsample size to force interpolation output size.")
900
+ forward_upsample_size = True
901
+
902
+ # prepare attention_mask
903
+ if attention_mask is not None:
904
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
905
+ attention_mask = attention_mask.unsqueeze(1)
906
+
907
+ # 0. center input if necessary
908
+ if self.config.center_input_sample:
909
+ sample = 2 * sample - 1.0
910
+
911
+ # 1. time
912
+ timesteps = timestep
913
+ if not torch.is_tensor(timesteps):
914
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
915
+ # This would be a good case for the `match` statement (Python 3.10+)
916
+ is_mps = sample.device.type == "mps"
917
+ if isinstance(timestep, float):
918
+ dtype = torch.float32 if is_mps else torch.float64
919
+ else:
920
+ dtype = torch.int32 if is_mps else torch.int64
921
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
922
+ elif len(timesteps.shape) == 0:
923
+ timesteps = timesteps[None].to(sample.device)
924
+
925
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
926
+ timesteps = timesteps.expand(sample.shape[0])
927
+
928
+ t_emb = self.time_proj(timesteps)
929
+
930
+ # `Timesteps` does not contain any weights and will always return f32 tensors
931
+ # but time_embedding might actually be running in fp16. so we need to cast here.
932
+ # there might be better ways to encapsulate this.
933
+ t_emb = t_emb.to(dtype=self.dtype)
934
+
935
+ emb = self.time_embedding(t_emb, timestep_cond)
936
+
937
+ if self.class_embedding is not None:
938
+ if class_labels is None:
939
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
940
+
941
+ if self.config.class_embed_type == "timestep":
942
+ class_labels = self.time_proj(class_labels)
943
+
944
+ # `Timesteps` does not contain any weights and will always return f32 tensors
945
+ # there might be better ways to encapsulate this.
946
+ class_labels = class_labels.to(dtype=sample.dtype)
947
+
948
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
949
+
950
+ if self.config.class_embeddings_concat:
951
+ emb = torch.cat([emb, class_emb], dim=-1)
952
+ else:
953
+ emb = emb + class_emb
954
+
955
+ if self.config.addition_embed_type == "text":
956
+ aug_emb = self.add_embedding(encoder_hidden_states)
957
+ emb = emb + aug_emb
958
+
959
+ if self.time_embed_act is not None:
960
+ emb = self.time_embed_act(emb)
961
+
962
+ if self.encoder_hid_proj is not None:
963
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
964
+
965
+ # 2. pre-process
966
+ sample = self.conv_in(sample)
967
+
968
+ # 3. down
969
+ down_block_res_samples = (sample,)
970
+ for i, downsample_block in enumerate(self.down_blocks):
971
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
972
+ sample, res_samples = downsample_block(
973
+ hidden_states=sample,
974
+ temb=emb,
975
+ encoder_hidden_states=encoder_hidden_states,
976
+ attention_mask=attention_mask,
977
+ cross_attention_kwargs=cross_attention_kwargs,
978
+ )
979
+ else:
980
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
981
+ #---------------------------------
982
+ # torch.save(sample, f'{timestep.item()}_step/down_{i}.pt')
983
+ #----------------------------------
984
+ down_block_res_samples += res_samples
985
+
986
+ if down_block_additional_residuals is not None:
987
+ new_down_block_res_samples = ()
988
+
989
+ for down_block_res_sample, down_block_additional_residual in zip(
990
+ down_block_res_samples, down_block_additional_residuals
991
+ ):
992
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
993
+ new_down_block_res_samples += (down_block_res_sample,)
994
+
995
+ down_block_res_samples = new_down_block_res_samples
996
+
997
+ # 4. mid
998
+ if self.mid_block is not None:
999
+ sample = self.mid_block(
1000
+ sample,
1001
+ emb,
1002
+ encoder_hidden_states=encoder_hidden_states,
1003
+ attention_mask=attention_mask,
1004
+ cross_attention_kwargs=cross_attention_kwargs,
1005
+ )
1006
+ # torch.save(sample, f'{timestep.item()}_step/mid.pt')
1007
+ if mid_block_additional_residual is not None:
1008
+ sample = sample + mid_block_additional_residual
1009
+ # 5. up
1010
+ for i, upsample_block in enumerate(self.up_blocks):
1011
+ is_final_block = i == len(self.up_blocks) - 1
1012
+
1013
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1014
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1015
+
1016
+ # if we have not reached the final block and need to forward the
1017
+ # upsample size, we do it here
1018
+ if not is_final_block and forward_upsample_size:
1019
+ upsample_size = down_block_res_samples[-1].shape[2:]
1020
+
1021
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1022
+ sample = upsample_block(
1023
+ hidden_states=sample,
1024
+ temb=emb,
1025
+ res_hidden_states_tuple=res_samples,
1026
+ encoder_hidden_states=encoder_hidden_states,
1027
+ cross_attention_kwargs=cross_attention_kwargs,
1028
+ upsample_size=upsample_size,
1029
+ attention_mask=attention_mask,
1030
+ )
1031
+ else:
1032
+ sample = upsample_block(
1033
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1034
+ )
1035
+ #----------------------------
1036
+ # torch.save(sample, f'{timestep.item()}_step/up_{i}.pt')
1037
+ #----------------------------
1038
+ # 6. post-process
1039
+ if self.conv_norm_out:
1040
+ sample = self.conv_norm_out(sample)
1041
+ sample = self.conv_act(sample)
1042
+ sample = self.conv_out(sample)
1043
+
1044
+ if not return_dict:
1045
+ return (sample,)
1046
+
1047
+ return UNet2DConditionOutput(sample=sample)
1048
+ return forward
1049
+ if model.__class__.__name__ == 'UNet2DConditionModel':
1050
+ model.forward = normal_forward(model)
1051
+
1052
+ def refister_time(unet, t):
1053
+ setattr(unet, 'order', t)
1054
+
1055
+
1056
+
1057
+ def register_controlnet_pipeline2(pipe):
1058
+ def new_call(self):
1059
+ @torch.no_grad()
1060
+ # @replace_example_docstring(EXAMPLE_DOC_STRING)
1061
+ def call(
1062
+ prompt: Union[str, List[str]] = None,
1063
+ image: Union[
1064
+ torch.FloatTensor,
1065
+ PIL.Image.Image,
1066
+ np.ndarray,
1067
+ List[torch.FloatTensor],
1068
+ List[PIL.Image.Image],
1069
+ List[np.ndarray],
1070
+ ] = None,
1071
+ height: Optional[int] = None,
1072
+ width: Optional[int] = None,
1073
+ num_inference_steps: int = 50,
1074
+ guidance_scale: float = 7.5,
1075
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1076
+ num_images_per_prompt: Optional[int] = 1,
1077
+ eta: float = 0.0,
1078
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1079
+ latents: Optional[torch.FloatTensor] = None,
1080
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1081
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1082
+ output_type: Optional[str] = "pil",
1083
+ return_dict: bool = True,
1084
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1085
+ callback_steps: int = 1,
1086
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1087
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
1088
+ guess_mode: bool = False,
1089
+ ):
1090
+ # 1. Check inputs. Raise error if not correct
1091
+ self.check_inputs(
1092
+ prompt,
1093
+ image,
1094
+ callback_steps,
1095
+ negative_prompt,
1096
+ prompt_embeds,
1097
+ negative_prompt_embeds,
1098
+ controlnet_conditioning_scale,
1099
+ )
1100
+
1101
+ # 2. Define call parameters
1102
+ if prompt is not None and isinstance(prompt, str):
1103
+ batch_size = 1
1104
+ elif prompt is not None and isinstance(prompt, list):
1105
+ batch_size = len(prompt)
1106
+ else:
1107
+ batch_size = prompt_embeds.shape[0]
1108
+
1109
+ device = self._execution_device
1110
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1111
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1112
+ # corresponds to doing no classifier free guidance.
1113
+ do_classifier_free_guidance = guidance_scale > 1.0
1114
+
1115
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1116
+
1117
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1118
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1119
+
1120
+ global_pool_conditions = (
1121
+ controlnet.config.global_pool_conditions
1122
+ if isinstance(controlnet, ControlNetModel)
1123
+ else controlnet.nets[0].config.global_pool_conditions
1124
+ )
1125
+ guess_mode = guess_mode or global_pool_conditions
1126
+
1127
+ # 3. Encode input prompt
1128
+ text_encoder_lora_scale = (
1129
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1130
+ )
1131
+ prompt_embeds = self._encode_prompt(
1132
+ prompt,
1133
+ device,
1134
+ num_images_per_prompt,
1135
+ do_classifier_free_guidance,
1136
+ negative_prompt,
1137
+ prompt_embeds=prompt_embeds,
1138
+ negative_prompt_embeds=negative_prompt_embeds,
1139
+ lora_scale=text_encoder_lora_scale,
1140
+ )
1141
+
1142
+ # 4. Prepare image
1143
+ if isinstance(controlnet, ControlNetModel):
1144
+ image = self.prepare_image(
1145
+ image=image,
1146
+ width=width,
1147
+ height=height,
1148
+ batch_size=batch_size * num_images_per_prompt,
1149
+ num_images_per_prompt=num_images_per_prompt,
1150
+ device=device,
1151
+ dtype=controlnet.dtype,
1152
+ do_classifier_free_guidance=do_classifier_free_guidance,
1153
+ guess_mode=guess_mode,
1154
+ )
1155
+ height, width = image.shape[-2:]
1156
+ elif isinstance(controlnet, MultiControlNetModel):
1157
+ images = []
1158
+
1159
+ for image_ in image:
1160
+ image_ = self.prepare_image(
1161
+ image=image_,
1162
+ width=width,
1163
+ height=height,
1164
+ batch_size=batch_size * num_images_per_prompt,
1165
+ num_images_per_prompt=num_images_per_prompt,
1166
+ device=device,
1167
+ dtype=controlnet.dtype,
1168
+ do_classifier_free_guidance=do_classifier_free_guidance,
1169
+ guess_mode=guess_mode,
1170
+ )
1171
+
1172
+ images.append(image_)
1173
+
1174
+ image = images
1175
+ height, width = image[0].shape[-2:]
1176
+ else:
1177
+ assert False
1178
+
1179
+ # 5. Prepare timesteps
1180
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1181
+ timesteps = self.scheduler.timesteps
1182
+
1183
+ # 6. Prepare latent variables
1184
+ num_channels_latents = self.unet.config.in_channels
1185
+ latents = self.prepare_latents(
1186
+ batch_size * num_images_per_prompt,
1187
+ num_channels_latents,
1188
+ height,
1189
+ width,
1190
+ prompt_embeds.dtype,
1191
+ device,
1192
+ generator,
1193
+ latents,
1194
+ )
1195
+ self.init_latent = latents.detach().clone()
1196
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1197
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1198
+
1199
+ # 8. Denoising loop
1200
+ #-------------------------------------------------------------
1201
+ all_steps = len(self.scheduler.timesteps)
1202
+ curr_span = 1
1203
+ curr_step = 0
1204
+
1205
+ # st = time.time()
1206
+ idx = 1
1207
+ keytime = [0,1,2,3,5,10,15,25,35,50]
1208
+
1209
+ while curr_step<all_steps:
1210
+ # torch.cuda.empty_cache()
1211
+ # print(curr_step)
1212
+ refister_time(self.unet, curr_step)
1213
+
1214
+ merge_span = curr_span
1215
+ if merge_span>0:
1216
+ time_ls = []
1217
+ for i in range(curr_step, curr_step+merge_span):
1218
+ if i<all_steps:
1219
+ time_ls.append(self.scheduler.timesteps[i])
1220
+ else:
1221
+ break
1222
+ # torch.cuda.empty_cache()
1223
+
1224
+ ##--------------------------------
1225
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1226
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, time_ls[0])
1227
+
1228
+ if curr_step in [0,1,2,3,5,10,15,25,35]:
1229
+ # controlnet(s) inference
1230
+ control_model_input = latent_model_input
1231
+ controlnet_prompt_embeds = prompt_embeds
1232
+
1233
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1234
+ control_model_input,
1235
+ time_ls[0],
1236
+ encoder_hidden_states=controlnet_prompt_embeds,
1237
+ controlnet_cond=image,
1238
+ conditioning_scale=controlnet_conditioning_scale,
1239
+ guess_mode=guess_mode,
1240
+ return_dict=False,
1241
+ )
1242
+
1243
+
1244
+ #----------------------save controlnet feature-------------------------
1245
+ #useless, shoule delete
1246
+ # setattr(self, 'downres_samples', deepcopy(down_block_res_samples))
1247
+ # setattr(self, 'midres_sample', mid_block_res_sample.detach().clone())
1248
+ #-----------------------save controlnet feature------------------------
1249
+ else:
1250
+ down_block_res_samples = None #self.downres_samples
1251
+ mid_block_res_sample = None #self.midres_sample
1252
+ # predict the noise residual
1253
+ noise_pred = self.unet(
1254
+ latent_model_input,
1255
+ time_ls,
1256
+ encoder_hidden_states=prompt_embeds,
1257
+ cross_attention_kwargs=cross_attention_kwargs,
1258
+ down_block_additional_residuals=down_block_res_samples,
1259
+ mid_block_additional_residual=mid_block_res_sample,
1260
+ return_dict=False,
1261
+ )[0]
1262
+
1263
+ # perform guidance
1264
+ if do_classifier_free_guidance:
1265
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1266
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1267
+
1268
+ # compute the previous noisy sample x_t -> x_t-1
1269
+
1270
+ if isinstance(time_ls, list):
1271
+ step_span = len(time_ls)
1272
+ bs = noise_pred.shape[0]
1273
+ bs_perstep = bs//step_span
1274
+
1275
+ denoised_latent = latents
1276
+ for i, timestep in enumerate(time_ls):
1277
+ curr_noise = noise_pred[i*bs_perstep:(i+1)*bs_perstep]
1278
+ denoised_latent = self.scheduler.step(curr_noise, timestep, denoised_latent, **extra_step_kwargs, return_dict=False)[0]
1279
+
1280
+ latents = denoised_latent
1281
+ ##----------------------------------------
1282
+ curr_step += curr_span
1283
+ idx += 1
1284
+ if curr_step<all_steps:
1285
+ curr_span = keytime[idx] - keytime[idx-1]
1286
+
1287
+ # for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="Sampling")):
1288
+
1289
+ #-------------------------------------------------------------
1290
+
1291
+
1292
+ # If we do sequential model offloading, let's offload unet and controlnet
1293
+ # manually for max memory savings
1294
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1295
+ self.unet.to("cpu")
1296
+ self.controlnet.to("cpu")
1297
+ torch.cuda.empty_cache()
1298
+
1299
+ if not output_type == "latent":
1300
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1301
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1302
+ else:
1303
+ image = latents
1304
+ has_nsfw_concept = None
1305
+
1306
+ if has_nsfw_concept is None:
1307
+ do_denormalize = [True] * image.shape[0]
1308
+ else:
1309
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1310
+
1311
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1312
+
1313
+ # Offload last model to CPU
1314
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1315
+ self.final_offload_hook.offload()
1316
+
1317
+ if not return_dict:
1318
+ return (image, has_nsfw_concept)
1319
+
1320
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
1321
+ return call
1322
+ pipe.call = new_call(pipe)
1323
+
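+ # register_controlnet_pipeline2 usage sketch (the checkpoint names below are common
+ # public ids, shown only as placeholders):
+ # from diffusers import StableDiffusionControlNetPipeline
+ # controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+ # pipe = StableDiffusionControlNetPipeline.from_pretrained(
+ #     "runwayml/stable-diffusion-v1-5", controlnet=controlnet)
+ # register_controlnet_pipeline2(pipe)
+ # register_faster_forward(pipe.unet, mod='50ls')
+ # out = pipe.call(prompt="a red sports car", image=get_canny("example.jpg"))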
1324
+ @torch.no_grad()
1325
+ def multistep_pre(self, noise_pred, t, x):
1326
+ step_span = len(t)
1327
+ bs = noise_pred.shape[0]
1328
+ bs_perstep = bs//step_span
1329
+
1330
+ denoised_latent = x
1331
+ for i, timestep in enumerate(t):
1332
+ curr_noise = noise_pred[i*bs_perstep:(i+1)*bs_perstep]
1333
+ denoised_latent = self.scheduler.step(curr_noise, timestep, denoised_latent)['prev_sample']
1334
+ return denoised_latent
1335
+
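+ # multistep_pre example (sketch): with t = [t1, t2, t3] and noise_pred of batch size 3,
+ # bs_perstep is 1, so slice i is fed to the scheduler at timestep t[i]; the latent is
+ # advanced three scheduler steps from a single batched UNet prediction.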
1336
+ def register_t2v(model):
1337
+ def new_back(self):
1338
+ def backward_loop(
1339
+ latents,
1340
+ timesteps,
1341
+ prompt_embeds,
1342
+ guidance_scale,
1343
+ callback,
1344
+ callback_steps,
1345
+ num_warmup_steps,
1346
+ extra_step_kwargs,
1347
+ cross_attention_kwargs=None,):
1348
+ do_classifier_free_guidance = guidance_scale > 1.0
1349
+ num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order
1350
+ import time
1351
+ if num_steps<10:
1352
+ with self.progress_bar(total=num_steps) as progress_bar:
1353
+ for i, t in enumerate(timesteps):
1354
+ setattr(self.unet, 'order', i)
1355
+ # expand the latents if we are doing classifier free guidance
1356
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1357
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1358
+
1359
+ # predict the noise residual
1360
+ noise_pred = self.unet(
1361
+ latent_model_input,
1362
+ t,
1363
+ encoder_hidden_states=prompt_embeds,
1364
+ cross_attention_kwargs=cross_attention_kwargs,
1365
+ ).sample
1366
+
1367
+ # perform guidance
1368
+ if do_classifier_free_guidance:
1369
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1370
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1371
+
1372
+ # compute the previous noisy sample x_t -> x_t-1
1373
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1374
+
1375
+ # call the callback, if provided
1376
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1377
+ progress_bar.update()
1378
+ if callback is not None and i % callback_steps == 0:
1379
+ step_idx = i // getattr(self.scheduler, "order", 1)
1380
+ callback(step_idx, t, latents)
1381
+
1382
+ else:
1383
+ all_timesteps = len(timesteps)
1384
+ curr_step = 0
1385
+
1386
+ while curr_step<all_timesteps:
1387
+ refister_time(self.unet, curr_step)
1388
+
1389
+ time_ls = []
1390
+ time_ls.append(timesteps[curr_step])
1391
+ curr_step += 1
1392
+ cond = curr_step in [0,1,2,3,5,10,15,25,35]
1393
+
1394
+ while (not cond) and (curr_step<all_timesteps):
1395
+ time_ls.append(timesteps[curr_step])
1396
+ curr_step += 1
1397
+ cond = curr_step in [0,1,2,3,5,10,15,25,35]
1398
+
1399
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1400
+ # predict the noise residual
1401
+ noise_pred = self.unet(
1402
+ latent_model_input,
1403
+ time_ls,
1404
+ encoder_hidden_states=prompt_embeds,
1405
+ cross_attention_kwargs=cross_attention_kwargs,
1406
+ ).sample
1407
+
1408
+ # perform guidance
1409
+ if do_classifier_free_guidance:
1410
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1411
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1412
+
1413
+ # compute the previous noisy sample x_t -> x_t-1
1414
+ latents = multistep_pre(self, noise_pred, time_ls, latents)
1415
+
1416
+ return latents.clone().detach()
1417
+ return backward_loop
1418
+ model.backward_loop = new_back(model)
1419
+
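+
+ if __name__ == "__main__":
+     # Minimal usage sketch: assumes a CUDA device and the public
+     # runwayml/stable-diffusion-v1-5 checkpoint; adjust both to your setup.
+     from diffusers import StableDiffusionPipeline
+
+     seed_everything(42)
+     pipe = StableDiffusionPipeline.from_pretrained(
+         "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+     ).to("cuda")
+     # Batched multi-timestep sampling loop plus encoder-feature caching at key steps.
+     register_parallel_pipeline(pipe)
+     register_faster_forward(pipe.unet, mod="50ls")
+     out = pipe.call(prompt="a photo of an astronaut riding a horse",
+                     num_inference_steps=50)
+     out.images[0].save("parallel_sample.png")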