3v324v23 committed on
Commit
0b15b54
1 Parent(s): 9b7b0f8

Add application file

pipline_StableDiffusion_ConsistentID.py ADDED
@@ -0,0 +1,674 @@
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+ import cv2
+ import PIL
+ import numpy as np
+ from PIL import Image
+ import torch
+ from torchvision import transforms
+ from insightface.app import FaceAnalysis
+ from safetensors import safe_open
+ from huggingface_hub.utils import validate_hf_hub_args
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
+ from diffusers.utils import _get_model_file
+ from functions import process_text_with_markers, masks_for_unique_values, fetch_mask_raw_image, tokenize_and_mask_noun_phrases_ends, prepare_image_token_idx
+ from functions import ProjPlusModel
+ from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder
+ from modelscope.outputs import OutputKeys
+ from modelscope.pipelines import pipeline
+
+ # TODO: BiSeNet library path (added to sys.path below)
+ import sys
+ sys.path.append("./models/BiSeNet")
+ from model import BiSeNet
+
+
+
+ PipelineImageInput = Union[
+     PIL.Image.Image,
+     torch.FloatTensor,
+     List[PIL.Image.Image],
+     List[torch.FloatTensor],
+ ]
+
+
+ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
+
+     @validate_hf_hub_args
+     def load_ConsistentID_model(
+         self,
+         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+         weight_name: str,
+         subfolder: str = '',
+         trigger_word_ID: str = '<|image|>',
+         trigger_word_facial: str = '<|facial|>',
+         image_encoder_path: str = '/data2/huangjiehui_m22/pretrained_model/CLIP-ViT-H-14-laion2B-s32B-b79K',  # TODO: path to the CLIP image encoder
+         torch_dtype = torch.float16,
+         num_tokens = 4,
+         lora_rank= 128,
+         **kwargs,
+     ):
+         self.lora_rank = lora_rank
+         self.torch_dtype = torch_dtype
+         self.num_tokens = num_tokens
+         self.set_ip_adapter()
+         self.image_encoder_path = image_encoder_path
+         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
+             self.device, dtype=self.torch_dtype
+         )
+         self.clip_image_processor = CLIPImageProcessor()
+         self.id_image_processor = CLIPImageProcessor()
+         self.crop_size = 512
+
+         # FaceID
+         self.app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+         self.app.prepare(ctx_id=0, det_size=(640, 640))
+
+         ### BiSeNet
+         self.bise_net = BiSeNet(n_classes = 19)
+         self.bise_net.cuda()
+         self.bise_net_cp='./models/BiSeNet_pretrained_for_ConsistentID.pth'  # TODO: BiSeNet checkpoint path
+         self.bise_net.load_state_dict(torch.load(self.bise_net_cp))
+         self.bise_net.eval()
+         # Colors for all 20 parts
+         self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
+                             [255, 0, 85], [255, 0, 170],
+                             [0, 255, 0], [85, 255, 0], [170, 255, 0],
+                             [0, 255, 85], [0, 255, 170],
+                             [0, 0, 255], [85, 0, 255], [170, 0, 255],
+                             [0, 85, 255], [0, 170, 255],
+                             [255, 255, 0], [255, 255, 85], [255, 255, 170],
+                             [255, 0, 255], [255, 85, 255], [255, 170, 255],
+                             [0, 255, 255], [85, 255, 255], [170, 255, 255]]
+
+         ### LLVA Optional
+         self.llva_model_path = "/data6/huangjiehui_m22/pretrained_model/llava-v1.5-7b"  # TODO: LLaVA model path
+         self.llva_prompt = "Describe this person's facial features for me, including face, ears, eyes, nose, and mouth."
+         self.llva_tokenizer, self.llva_model, self.llva_image_processor, self.llva_context_len = None,None,None,None #load_pretrained_model(self.llva_model_path)
+
+         self.image_proj_model = ProjPlusModel(
+             cross_attention_dim=self.unet.config.cross_attention_dim,
+             id_embeddings_dim=512,
+             clip_embeddings_dim=self.image_encoder.config.hidden_size,
+             num_tokens=self.num_tokens,  # 4
+         ).to(self.device, dtype=self.torch_dtype)
+         self.FacialEncoder = FacialEncoder(self.image_encoder).to(self.device, dtype=self.torch_dtype)
+
+         # ModelScope pipeline for optional skin retouching
+         self.skin_retouching = pipeline('skin-retouching-torch', model='damo/cv_unet_skin_retouching_torch', model_revision='v1.0.2')
+
+         # Load the main state dict first.
+         cache_dir = kwargs.pop("cache_dir", None)
+         force_download = kwargs.pop("force_download", False)
+         resume_download = kwargs.pop("resume_download", False)
+         proxies = kwargs.pop("proxies", None)
+         local_files_only = kwargs.pop("local_files_only", None)
+         token = kwargs.pop("token", None)
+         revision = kwargs.pop("revision", None)
+
+         user_agent = {
+             "file_type": "attn_procs_weights",
+             "framework": "pytorch",
+         }
+
+         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+             model_file = _get_model_file(
+                 pretrained_model_name_or_path_or_dict,
+                 weights_name=weight_name,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 resume_download=resume_download,
+                 proxies=proxies,
+                 local_files_only=local_files_only,
+                 use_auth_token=token,
+                 revision=revision,
+                 subfolder=subfolder,
+                 user_agent=user_agent,
+             )
+             if weight_name.endswith(".safetensors"):
+                 state_dict = {"id_encoder": {}, "lora_weights": {}}
+                 with safe_open(model_file, framework="pt", device="cpu") as f:
+                     for key in f.keys():
+                         if key.startswith("id_encoder."):
+                             state_dict["id_encoder"][key.replace("id_encoder.", "")] = f.get_tensor(key)
+                         elif key.startswith("lora_weights."):
+                             state_dict["lora_weights"][key.replace("lora_weights.", "")] = f.get_tensor(key)
+             else:
+                 state_dict = torch.load(model_file, map_location="cpu")
+         else:
+             state_dict = pretrained_model_name_or_path_or_dict
+
+         self.trigger_word_ID = trigger_word_ID
+         self.trigger_word_facial = trigger_word_facial
+
+         self.FacialEncoder.load_state_dict(state_dict["FacialEncoder"], strict=True)
+         self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
+         ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
+         ip_layers.load_state_dict(state_dict["adapter_modules"], strict=True)
+         print("Successfully loaded weights from checkpoint")
+
+         # Add trigger word token
+         if self.tokenizer is not None:
+             self.tokenizer.add_tokens([self.trigger_word_ID], special_tokens=True)
+             self.tokenizer.add_tokens([self.trigger_word_facial], special_tokens=True)
+
+     def set_ip_adapter(self):
+         unet = self.unet
+         attn_procs = {}
+         for name in unet.attn_processors.keys():
+             cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+             if name.startswith("mid_block"):
+                 hidden_size = unet.config.block_out_channels[-1]
+             elif name.startswith("up_blocks"):
+                 block_id = int(name[len("up_blocks.")])
+                 hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+             elif name.startswith("down_blocks"):
+                 block_id = int(name[len("down_blocks.")])
+                 hidden_size = unet.config.block_out_channels[block_id]
+             if cross_attention_dim is None:
+                 attn_procs[name] = Consistent_AttProcessor(
+                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
+                 ).to(self.device, dtype=self.torch_dtype)
+             else:
+                 attn_procs[name] = Consistent_IPAttProcessor(
+                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
+                 ).to(self.device, dtype=self.torch_dtype)
+
+         unet.set_attn_processor(attn_procs)
+
+     @torch.inference_mode()
+     def get_facial_embeds(self, prompt_embeds, negative_prompt_embeds, facial_clip_images, facial_token_masks, valid_facial_token_idx_mask):
+
+         hidden_states = []
+         uncond_hidden_states = []
+         for facial_clip_image in facial_clip_images:
+             # Encode each cropped facial-region image with the CLIP image encoder
+             hidden_state = self.image_encoder(facial_clip_image.to(self.device, dtype=self.torch_dtype), output_hidden_states=True).hidden_states[-2]
+             uncond_hidden_state = self.image_encoder(torch.zeros_like(facial_clip_image, dtype=self.torch_dtype).to(self.device), output_hidden_states=True).hidden_states[-2]
+             hidden_states.append(hidden_state)
+             uncond_hidden_states.append(uncond_hidden_state)
+         multi_facial_embeds = torch.stack(hidden_states)
+         uncond_multi_facial_embeds = torch.stack(uncond_hidden_states)
+
+         # Condition: the FacialEncoder injects the facial embeddings into the prompt embeddings at the trigger-token positions
+         facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
+
+         # Uncondition
+         uncond_facial_prompt_embeds = self.FacialEncoder(negative_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
+
+         return facial_prompt_embeds, uncond_facial_prompt_embeds
+
+     @torch.inference_mode()
+     def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut=False):
+
+         clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
+         clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
+         # Preprocess to a 1x3x224x224 clip_image, then encode it into a 1x257x1280 hidden state
+         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+         uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
+         # The unconditional embedding is computed from an all-zero image and is used for classifier-free guidance
+         faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
+         image_prompt_tokens = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
+         uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
+         # image_prompt_tokens: the face-ID embedding after attending over the CLIP features of the whole image
+         # uncond_image_prompt_embeds: the same projection applied to a zero face ID and a zero image
+         return image_prompt_tokens, uncond_image_prompt_embeds
+
+     def set_scale(self, scale):
+         for attn_processor in self.unet.attn_processors.values():
+             if isinstance(attn_processor, Consistent_IPAttProcessor):
+                 attn_processor.scale = scale
+
+     @torch.inference_mode()
+     def get_prepare_faceid(self, face_image):
+         faceid_image = np.array(face_image)
+         # Extract the face ID embedding with the InsightFace model
+         faces = self.app.get(faceid_image)
+         if faces==[]:
+             faceid_embeds = torch.zeros_like(torch.empty((1, 512)))
+         else:  # InsightFace returns a 512-dim embedding; convert it to torch and add a batch dimension
+             faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
+         # If no face is detected, the returned ID embedding stays all-zero
+         return faceid_embeds
+
+     @torch.inference_mode()
+     def parsing_face_mask(self, raw_image_refer):
+
+         to_tensor = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+         ])
+         to_pil = transforms.ToPILImage()
+
+         with torch.no_grad():
+             image = raw_image_refer.resize((512, 512), Image.BILINEAR)
+             image_resize_PIL = image
+             img = to_tensor(image)
+             img = torch.unsqueeze(img, 0)
+             img = img.float().cuda()
+             out = self.bise_net(img)[0]  # 1,19,512,512
+             parsing_anno = out.squeeze(0).cpu().numpy().argmax(0)  # 512,512; per pixel, the index of the strongest of the 19 classes
+
+         im = np.array(image_resize_PIL)
+         vis_im = im.copy().astype(np.uint8)
+         stride=1
+         vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
+         # With stride=1 this resize is effectively a no-op; it is kept from the original BiSeNet visualization code
+         vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
+         vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
+
+         num_of_class = np.max(vis_parsing_anno)
+
+         for pi in range(1, num_of_class + 1):  # e.g. num_of_class=17, so pi runs over the detected classes
+             index = np.where(vis_parsing_anno == pi)
+             vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi]
+
+         vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)  # colorized parsing mask only
+         vis_parsing_anno_color = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)  # blended with the original image
+
+         return vis_parsing_anno_color, vis_parsing_anno
+
+     @torch.inference_mode()
+     def get_prepare_llva_caption(self, input_image_file, model_path=None, prompt=None):
+
+         ### Optional: Use the LLaVA
+         # args = type('Args', (), {
+         #     "model_path": self.llva_model_path,
+         #     "model_base": None,
+         #     "model_name": get_model_name_from_path(self.llva_model_path),
+         #     "query": self.llva_prompt,
+         #     "conv_mode": None,
+         #     "image_file": input_image_file,
+         #     "sep": ",",
+         #     "temperature": 0,
+         #     "top_p": None,
+         #     "num_beams": 1,
+         #     "max_new_tokens": 512
+         # })()
+         # face_caption = eval_model(args, self.llva_tokenizer, self.llva_model, self.llva_image_processor)
+
+         ### Use built-in template
+         face_caption = "The person has one nose, two eyes, two ears, and a mouth."
+
+         return face_caption
+
+
+
+     @torch.inference_mode()
+     def get_prepare_facemask(self, input_image_file):
+         # Get the face parsing masks first
+         vis_parsing_anno_color, vis_parsing_anno = self.parsing_face_mask(input_image_file)
+         parsing_mask_list = masks_for_unique_values(vis_parsing_anno)
+
+         key_parsing_mask_list = {}
+         key_list = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
+         # TODO: the background region is not used yet; decide whether it is worth adding
+
+         processed_keys = set()
+         for key, mask_image in parsing_mask_list.items():
+             if key in key_list:
+                 if "_" in key:
+                     prefix = key.split("_")[1]
+                     if prefix in processed_keys:  # paired parts (left/right ear, left/right eye) are only processed once
+                         continue
+                     else:
+                         key_parsing_mask_list[key] = mask_image
+                         processed_keys.add(prefix)
+
+                 key_parsing_mask_list[key] = mask_image
+
+         return key_parsing_mask_list, vis_parsing_anno_color
+
+     def encode_prompt_with_trigger_word(
+         self,
+         prompt: str,
+         face_caption: str,
+         key_parsing_mask_list = None,
+         image_token = "<|image|>",
+         facial_token = "<|facial|>",
+         max_num_facials = 5,
+         num_id_images: int = 1,
+         device: Optional[torch.device] = None,
+     ):
+         device = device or self._execution_device
+         # A key step of the paper (although it seems to matter little at inference): rewrite the face caption
+         # so that its keywords follow the order of the entries in key_parsing_mask_list (eyes, ears, nose, ...)
+         # and append the "<|facial|>" text marker after each of those keywords.
+         face_caption_align, key_parsing_mask_list_align = process_text_with_markers(face_caption, key_parsing_mask_list)
+
+         # Combine with the user-provided prompt
+         prompt_face = prompt + "Detail:" + face_caption_align
+
+         max_text_length=330  # If the user prompt is very long, move face_caption_align (which carries the facial markers) to the front so it is not truncated
+         if len(self.tokenizer(prompt_face, max_length=self.tokenizer.model_max_length, padding="max_length",truncation=False,return_tensors="pt").input_ids[0])!=77:
+             prompt_face = "Detail:" + face_caption_align + " Caption:" + prompt
+
+         if len(face_caption)>max_text_length:
+             prompt_face = prompt
+             face_caption_align = ""
+
+         prompt_text_only = prompt_face.replace("<|facial|>", "").replace("<|image|>", "")
+         tokenizer = self.tokenizer
+         # Resolve the "<|facial|>" trigger word to its token id (49409 after add_tokens)
+         facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
+         image_token_id = None  # TODO: reserved for the "<|image|>" trigger word (level 2), currently unused
+
+         # clean_input_id: the usual 1x77 SD token ids, without the trigger-word tokens
+         # image_token_mask: a 1x77 boolean mask, currently all False and unused
+         # facial_token_mask: a 1x77 boolean mask that is True at the facial trigger-word positions
+         # TODO: the sequence length is fixed at 77, which limits how much prompt engineering fits in
+         clean_input_id, image_token_mask, facial_token_mask = tokenize_and_mask_noun_phrases_ends(
+             prompt_face, image_token_id, facial_token_id, tokenizer)
+         # facial_token_idx holds the positions of the facial trigger tokens; the image_* outputs are currently unused
+         image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = prepare_image_token_idx(
+             image_token_mask, facial_token_mask, num_id_images, max_num_facials )
+
+         return prompt_text_only, clean_input_id, key_parsing_mask_list_align, facial_token_mask, facial_token_idx, facial_token_idx_mask
+
+     @torch.inference_mode()
+     def get_prepare_clip_image(self, input_image_file, key_parsing_mask_list, image_size=512, max_num_facials=5, change_facial=True):
+
+         facial_mask = []
+         facial_clip_image = []
+         transform_mask = transforms.Compose([transforms.CenterCrop(size=image_size), transforms.ToTensor(),])
+         clip_image_processor = CLIPImageProcessor()
+
+         num_facial_part = len(key_parsing_mask_list)
+         # For each facial part, intersect its mask with the original image and keep the crop in facial_clip_image
+         for key in key_parsing_mask_list:
+             key_mask=key_parsing_mask_list[key]
+             facial_mask.append(transform_mask(key_mask))
+             # key_mask_raw_image is the region of the original image selected by this facial-part mask
+             key_mask_raw_image = fetch_mask_raw_image(input_image_file,key_mask)
+             parsing_clip_image = clip_image_processor(images=key_mask_raw_image, return_tensors="pt").pixel_values
+             facial_clip_image.append(parsing_clip_image)
+
+         padding_facial_clip_image = torch.zeros([1, 3, 224, 224])
+         padding_facial_mask = torch.zeros([1, image_size, image_size])
+         # Pad facial_clip_image and facial_mask with zero tensors up to max_num_facials entries
+         if num_facial_part < max_num_facials:
+             facial_clip_image += [torch.zeros_like(padding_facial_clip_image) for _ in range(max_num_facials - num_facial_part) ]
+             facial_mask += [ torch.zeros_like(padding_facial_mask) for _ in range(max_num_facials - num_facial_part)]
+
+         facial_clip_image = torch.stack(facial_clip_image, dim=1).squeeze(0)
+         facial_mask = torch.stack(facial_mask, dim=0).squeeze(dim=1)
+
+         return facial_clip_image, facial_mask  # facial_mask is only used for the training loss; it is not needed at inference
+
+     # Entry point of the pipeline
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 5.0,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         original_size: Optional[Tuple[int, int]] = None,
+         target_size: Optional[Tuple[int, int]] = None,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         callback_steps: int = 1,
+         input_id_images: PipelineImageInput = None,
+         reference_id_images: PipelineImageInput =None,
+         start_merge_step: int = 0,
+         class_tokens_mask: Optional[torch.LongTensor] = None,
+         prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
+         retouching: bool=False,
+         need_safetycheck: bool=True,
+     ):
+         # 0. Default height and width to unet
+         height = height or self.unet.config.sample_size * self.vae_scale_factor
+         width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+         original_size = original_size or (height, width)
+         target_size = target_size or (height, width)
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             height,
+             width,
+             callback_steps,
+             negative_prompt,
+             prompt_embeds,
+             negative_prompt_embeds,
+         )
+         if not isinstance(input_id_images, list):
+             input_id_images = [input_id_images]
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)  # TODO
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+         do_classifier_free_guidance = guidance_scale >= 1.0
+         input_image_file = input_id_images[0]
+
+         # 4-14: tried blending the face IDs of multiple reference photos; it did not make much difference
+         if reference_id_images:
+             references_faceid_embeds=[]
+             for reference_image in reference_id_images:
+                 references_faceid_embeds.append(self.get_prepare_faceid(face_image=reference_image))
+             references_faceid_embeds = torch.stack(references_faceid_embeds, dim=0)  # torch.Size([16, 1, 512])
+             references_faceid_embeds_mean=torch.mean(references_faceid_embeds, dim=0)
+             # references_faceid_embeds_var=torch.var(references_faceid_embeds, dim=0)
+             # references_faceid_embeds_sample=torch.normal(references_faceid_embeds_mean, references_faceid_embeds_var)
+
+         # The face ID used at inference is extracted here by InsightFace, shape 1x512
+         faceid_embeds = self.get_prepare_faceid(face_image=input_image_file)  # TODO: keep this line (and comment out the next) when running with Gradio
+         # faceid_embeds = references_faceid_embeds_mean  # use the mean ID of the reference set instead
+         # Note: the detailed LLaVA face caption is not used at inference; a fixed template is returned instead
+         face_caption = self.get_prepare_llva_caption(input_image_file)
+         # Known issues: the left eye/left ear are sometimes missed, and the right-ear mask is tiny (better than nothing)
+         key_parsing_mask_list, vis_parsing_anno_color = self.get_prepare_facemask(input_image_file)
+
+         # Classifier-free guidance is required here (guidance_scale >= 1.0), otherwise this assertion fails
+         assert do_classifier_free_guidance
+
+         # 3. Encode input prompt
+         num_id_images = len(input_id_images)
+
+         (
+             prompt_text_only,  # user prompt combined with the template caption, without facial trigger words
+             clean_input_id,  # token ids of prompt_text_only
+             key_parsing_mask_list_align,  # PIL masks, essentially for the eyes, ears, nose and mouth
+             facial_token_mask,  # mostly False, True at the facial trigger-word positions
+             facial_token_idx,  # the useful one: indices of the facial trigger tokens (length-5 array)
+             facial_token_idx_mask,
+         ) = self.encode_prompt_with_trigger_word(
+             prompt = prompt,
+             face_caption = face_caption,  # fixed template caption
+             key_parsing_mask_list=key_parsing_mask_list,
+             device=device,
+             max_num_facials = 5,
+             num_id_images= num_id_images,
+             # prompt_embeds= None,
+             # pooled_prompt_embeds= None,
+             # class_tokens_mask= None,
+         )
+
+         # 4. Encode input prompt without the trigger word for delayed conditioning
+         encoder_hidden_states = self.text_encoder(clean_input_id.to(device))[0]
+         # encoder_hidden_states: clean_input_id passed through the unmodified CLIP text encoder, a 1x77x768 tensor
+         prompt_embeds = self._encode_prompt(
+             prompt_text_only,
+             device=device,
+             num_images_per_prompt=num_images_per_prompt,
+             do_classifier_free_guidance=True,
+             negative_prompt=negative_prompt,  # standard StableDiffusionPipeline helper
+         )  # 2x77x768: the first half is the negative (unconditional) embedding, the second half the text-only embedding
+         negative_encoder_hidden_states_text_only = prompt_embeds[0:num_images_per_prompt]
+         encoder_hidden_states_text_only = prompt_embeds[num_images_per_prompt:]
+
+         # 5. Prepare the input ID images
+         prompt_tokens_faceid, uncond_prompt_tokens_faceid = self.get_image_embeds(faceid_embeds, face_image=input_image_file, s_scale=0.0, shortcut=True)
+         # TODO: s_scale was changed to 0.0 here; feeding the faceid_embeds residual was tried and did not help
+         # Both outputs are 1x4x768; prompt_tokens_faceid is the face ID after attending over the whole-image CLIP features
+         # uncond_prompt_tokens_faceid is kept because the classifier-free-guidance formula needs it to preserve some diversity
+         facial_clip_image, facial_mask = self.get_prepare_clip_image(input_image_file, key_parsing_mask_list_align, image_size=512, max_num_facials=5)
+         # facial_clip_image is 5x3x224x224, facial_mask is 5x512x512; inference only uses facial_clip_image (masked crops rescaled to 224x224)
+         facial_clip_images = facial_clip_image.unsqueeze(0).to(device, dtype=self.torch_dtype)
+         # (It can be worth visualizing facial_clip_images to check how much of the face the crops actually cover)
+         facial_token_mask = facial_token_mask.to(device)
+         facial_token_idx_mask = facial_token_idx_mask.to(device)
+         negative_encoder_hidden_states = negative_encoder_hidden_states_text_only
+
+         cross_attention_kwargs = {}
+
+         # 6. Get the updated text embeddings
+         prompt_embeds_facial, uncond_prompt_embeds_facial = self.get_facial_embeds(encoder_hidden_states, negative_encoder_hidden_states, \
+             facial_clip_images, facial_token_mask, facial_token_idx_mask)
+         # prompt_embeds_facial: the text embedding with the CLIP features of the facial crops substituted at the marker positions
+         # prompt_tokens_faceid: the InsightFace ID embedding fused (via attention) with the whole-image CLIP features
+         prompt_embeds = torch.cat([prompt_embeds_facial, prompt_tokens_faceid], dim=1)
+         negative_prompt_embeds = torch.cat([uncond_prompt_embeds_facial, uncond_prompt_tokens_faceid], dim=1)
+         # TODO: prompt_embeds is passed through _encode_prompt once more; since the embeddings are already supplied,
+         # the standard SD helper mainly duplicates/concatenates them for classifier-free guidance
+         prompt_embeds = self._encode_prompt(
+             prompt,
+             device,
+             num_images_per_prompt,
+             do_classifier_free_guidance,
+             negative_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+         )
+         # prompt_embeds is now torch.Size([2, 81, 768]): the unconditional and conditional embeddings concatenated on dim 0
+         # The next two lines append the text-only + face-ID embedding, giving torch.Size([3, 81, 768])
+         prompt_embeds_text_only = torch.cat([encoder_hidden_states_text_only, prompt_tokens_faceid], dim=1)
+         prompt_embeds = torch.cat([prompt_embeds, prompt_embeds_text_only], dim=0)
+
+         # 7. Prepare timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         # 8. Prepare latent variables
+         num_channels_latents = self.unet.in_channels
+         latents = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+         (
+             null_prompt_embeds,  # unconditional (negative) prompt embedding
+             augmented_prompt_embeds,  # enhanced text prompt + ID prompt
+             text_prompt_embeds,  # text prompt + ID prompt
+         ) = prompt_embeds.chunk(3)
+
+         # 9. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 latent_model_input = (
+                     torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                 )
+                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                 if i <= start_merge_step:
+                     current_prompt_embeds = torch.cat(
+                         [null_prompt_embeds, text_prompt_embeds], dim=0
+                     )
+                 else:
+                     current_prompt_embeds = torch.cat(
+                         [null_prompt_embeds, augmented_prompt_embeds], dim=0
+                     )
+
+                 # predict the noise residual (with the customized attention processors)
+                 noise_pred = self.unet(
+                     latent_model_input,
+                     t,
+                     encoder_hidden_states=current_prompt_embeds,
+                     cross_attention_kwargs=cross_attention_kwargs,
+                 ).sample
+
+                 # perform guidance
+                 if do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + guidance_scale * (
+                         noise_pred_text - noise_pred_uncond
+                     )
+                 else:
+                     assert 0, "Not Implemented"
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(
+                     noise_pred, t, latents, **extra_step_kwargs
+                 ).prev_sample
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or (
+                     (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                 ):
+                     progress_bar.update()
+                     if callback is not None and i % callback_steps == 0:
+                         callback(i, t, latents)
+
+         if output_type == "latent":
+             image = latents
+             has_nsfw_concept = None
+         elif output_type == "pil":  # the default
+             # 9.1 Post-processing
+             image = self.decode_latents(latents)
+
+             # 9.2 Run safety checker
+             if need_safetycheck:
+                 image, has_nsfw_concept = self.run_safety_checker(
+                     image, device, prompt_embeds.dtype
+                 )
+             else:
+                 has_nsfw_concept = None
+
+             # 9.3 Convert to PIL list
+             image = self.numpy_to_pil(image)
+
+             # Optional skin retouching via ModelScope (the pipeline takes a PIL image and returns a BGR array).
+             # Off by default because the retouching model fails fairly often.
+             if retouching:
+                 after_retouching = self.skin_retouching(image[0])
+                 if OutputKeys.OUTPUT_IMG in after_retouching:
+                     image = [Image.fromarray(cv2.cvtColor(after_retouching[OutputKeys.OUTPUT_IMG], cv2.COLOR_BGR2RGB))]
+         else:
+             # 9.1 Post-processing
+             image = self.decode_latents(latents)
+
+             # 9.2 Run safety checker
+             image, has_nsfw_concept = self.run_safety_checker(
+                 image, device, prompt_embeds.dtype
+             )
+
+
+         # Offload last model to CPU
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+             self.final_offload_hook.offload()
+
+         if not return_dict:
+             return (image, has_nsfw_concept)
+
+         return StableDiffusionPipelineOutput(
+             images=image, nsfw_content_detected=has_nsfw_concept
+         )
+
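For reference, the snippet below is a minimal usage sketch of the pipeline defined in this file. The base-model ID, checkpoint file name, image-encoder ID, and file paths are illustrative assumptions only and are not part of this commit; they need to be adapted to the actual checkpoints before running.

import torch
from PIL import Image
from pipline_StableDiffusion_ConsistentID import ConsistentIDStableDiffusionPipeline

# Hypothetical model IDs and paths; adjust to the real checkpoints.
pipe = ConsistentIDStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",   # assumed SD 1.5 base model
    torch_dtype=torch.float16,
).to("cuda")

pipe.load_ConsistentID_model(
    "./models",                          # assumed folder holding the ConsistentID checkpoint
    weight_name="ConsistentID-v1.bin",   # assumed checkpoint file name
    image_encoder_path="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",  # Hub ID used instead of the hard-coded local path
)

id_image = Image.open("./examples/face.jpg").convert("RGB")  # assumed portrait of the target identity

result = pipe(
    prompt="A person wearing a suit, standing in a garden, cinematic lighting",
    negative_prompt="blurry, low quality, deformed",
    input_id_images=[id_image],
    num_inference_steps=50,
    guidance_scale=5.0,
    start_merge_step=10,
)
result.images[0].save("consistentid_result.jpg")

Note that start_merge_step delays the facial-detail conditioning: denoising steps up to start_merge_step use only the text plus face-ID embedding, and later steps switch to the embedding augmented with the facial-region features.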