Spaces:
Running
on
Zero
Running
on
Zero
from typing import Any, Callable, Dict, List, Optional, Union, Tuple | |
import cv2 | |
import PIL | |
import numpy as np | |
from PIL import Image | |
import torch | |
from torchvision import transforms | |
from insightface.app import FaceAnalysis | |
from safetensors import safe_open | |
from huggingface_hub.utils import validate_hf_hub_args | |
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection | |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput | |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline | |
from diffusers.utils import _get_model_file | |
from functions import process_text_with_markers, masks_for_unique_values, fetch_mask_raw_image, tokenize_and_mask_noun_phrases_ends, prepare_image_token_idx | |
from functions import ProjPlusModel, masks_for_unique_values | |
from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder | |
from modelscope.outputs import OutputKeys | |
from modelscope.pipelines import pipeline | |
#TODO 引入BiSeNet库路径 | |
import sys | |
sys.path.append("./models/BiSeNet") | |
from model import BiSeNet | |
# Face images accepted by the pipeline: one PIL image or tensor, or a list of
# either. typing flattens the nested Union, so this is the same alias as a flat
# four-member Union.
_SingleImageInput = Union[PIL.Image.Image, torch.FloatTensor]
PipelineImageInput = Union[
    _SingleImageInput,
    List[PIL.Image.Image],
    List[torch.FloatTensor],
]
class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline conditioned on a reference face (ConsistentID).

    Identity enters the text conditioning in two ways:
      * An InsightFace ID embedding and a global CLIP image embedding go through
        ``ProjPlusModel``. The resulting tokens are appended to the text
        embeddings along the token dimension.
      * CLIP features of individual facial parts are passed to ``FacialEncoder``
        together with the ``<|facial|>`` token mask. The parts are cropped with
        BiSeNet face-parsing masks. Per the original author's notes, the encoder
        replaces the embeddings at the trigger positions.

    ``load_ConsistentID_model`` must be called once before ``__call__``.
    """

    def load_ConsistentID_model(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        weight_name: str,
        subfolder: str = '',
        trigger_word_ID: str = '<|image|>',
        trigger_word_facial: str = '<|facial|>',
        image_encoder_path: str = '/data2/huangjiehui_m22/pretrained_model/CLIP-ViT-H-14-laion2B-s32B-b79K',  # TODO: CLIP encoder path
        torch_dtype=torch.float16,
        num_tokens=4,
        lora_rank=128,
        **kwargs,
    ):
        """Load all auxiliary models and the ConsistentID adapter weights.

        Args:
            pretrained_model_name_or_path_or_dict: Hub repo id or local directory
                holding ``weight_name``. Alternatively, an already-loaded state
                dict with the sub-dicts ``"FacialEncoder"``, ``"image_proj"`` and
                ``"adapter_modules"``.
            weight_name: Checkpoint filename, either ``.safetensors`` or a torch
                pickle.
            subfolder: Subfolder of the repo/directory that contains
                ``weight_name``.
            trigger_word_ID: Identity trigger token added to the tokenizer.
            trigger_word_facial: Facial-part trigger token added to the
                tokenizer.
            image_encoder_path: Location of the CLIP vision encoder.
            torch_dtype: dtype used for the image encoder, projection model,
                facial encoder and attention processors.
            num_tokens: Number of identity tokens produced by ``ProjPlusModel``.
            lora_rank: Rank passed to the ConsistentID attention processors.
            **kwargs: Hub download options: ``cache_dir``, ``force_download``,
                ``resume_download``, ``proxies``, ``local_files_only``,
                ``token`` and ``revision``.

        Raises:
            KeyError: If the checkpoint lacks one of the three required
                components.
        """
        self.lora_rank = lora_rank
        self.torch_dtype = torch_dtype
        self.num_tokens = num_tokens
        # set_ip_adapter reads lora_rank / torch_dtype / num_tokens, so it must
        # run after they are set.
        self.set_ip_adapter()
        self.image_encoder_path = image_encoder_path
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=self.torch_dtype
        )
        self.clip_image_processor = CLIPImageProcessor()
        self.id_image_processor = CLIPImageProcessor()
        self.crop_size = 512

        # InsightFace analysis app; get_prepare_faceid reads the 512-d
        # normed_embedding from it.
        self.app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app.prepare(ctx_id=0, det_size=(640, 640))

        # BiSeNet face parser with 19 classes.
        self.bise_net = BiSeNet(n_classes=19)
        self.bise_net.cuda()
        self.bise_net_cp = './models/BiSeNet_pretrained_for_ConsistentID.pth'  # TODO: BiSeNet checkpoint path
        # Load tensors on CPU first so a GPU-saved checkpoint loads on any
        # machine; load_state_dict then copies them into the CUDA module.
        self.bise_net.load_state_dict(torch.load(self.bise_net_cp, map_location="cpu"))
        self.bise_net.eval()
        # Visualisation palette indexed by parsing class id. It has 24 entries;
        # BiSeNet's argmax only yields ids 0-18, and id 0 is left white.
        self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                            [255, 0, 85], [255, 0, 170],
                            [0, 255, 0], [85, 255, 0], [170, 255, 0],
                            [0, 255, 85], [0, 255, 170],
                            [0, 0, 255], [85, 0, 255], [170, 0, 255],
                            [0, 85, 255], [0, 170, 255],
                            [255, 255, 0], [255, 255, 85], [255, 255, 170],
                            [255, 0, 255], [255, 85, 255], [255, 170, 255],
                            [0, 255, 255], [85, 255, 255], [170, 255, 255]]

        # Optional LLaVA captioner. It is not loaded; get_prepare_llva_caption
        # returns a fixed template instead.
        self.llva_model_path = "/data6/huangjiehui_m22/pretrained_model/llava-v1.5-7b"  # TODO: LLaVA model path
        self.llva_prompt = "Describe this person's facial features for me, including face, ears, eyes, nose, and mouth."
        self.llva_tokenizer, self.llva_model, self.llva_image_processor, self.llva_context_len = None, None, None, None

        self.image_proj_model = ProjPlusModel(
            cross_attention_dim=self.unet.config.cross_attention_dim,
            id_embeddings_dim=512,
            clip_embeddings_dim=self.image_encoder.config.hidden_size,
            num_tokens=self.num_tokens,
        ).to(self.device, dtype=self.torch_dtype)
        self.FacialEncoder = FacialEncoder(self.image_encoder).to(self.device, dtype=self.torch_dtype)

        # ModelScope skin-retouching model, used only when __call__(retouching=True).
        self.skin_retouching = pipeline('skin-retouching-torch', model='damo/cv_unet_skin_retouching_torch', model_revision='v1.0.2')

        # Resolve and load the ConsistentID checkpoint.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }
        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
            if weight_name.endswith(".safetensors"):
                # safetensors stores flat "<component>.<param>" keys. Regroup
                # them into the nested {component: {param: tensor}} layout that
                # the loading code below (and the torch checkpoint) uses.
                state_dict = {}
                with safe_open(model_file, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        component, _, param_name = key.partition(".")
                        state_dict.setdefault(component, {})[param_name] = f.get_tensor(key)
            else:
                # NOTE: torch.load unpickles the file; only load trusted checkpoints.
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        missing = [k for k in ("FacialEncoder", "image_proj", "adapter_modules") if k not in state_dict]
        if missing:
            raise KeyError(f"ConsistentID checkpoint is missing components: {missing}")

        self.trigger_word_ID = trigger_word_ID
        self.trigger_word_facial = trigger_word_facial

        self.FacialEncoder.load_state_dict(state_dict["FacialEncoder"], strict=True)
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        # The adapter weights are saved as a ModuleList over the UNet's
        # attention processors, in attn_processors iteration order.
        ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["adapter_modules"], strict=True)
        print("Successfully loaded weights from checkpoint")

        # Register the trigger words so the tokenizer keeps them as single tokens.
        if self.tokenizer is not None:
            self.tokenizer.add_tokens([self.trigger_word_ID], special_tokens=True)
            self.tokenizer.add_tokens([self.trigger_word_facial], special_tokens=True)

    def set_ip_adapter(self):
        """Install ConsistentID attention processors on every UNet attention layer.

        Self-attention layers (``attn1``) get ``Consistent_AttProcessor``.
        Cross-attention layers get ``Consistent_IPAttProcessor``, sized for
        ``self.num_tokens`` identity tokens. Reads ``self.lora_rank``,
        ``self.num_tokens`` and ``self.torch_dtype``.
        """
        unet = self.unet
        block_out_channels = unet.config.block_out_channels
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = block_out_channels[-1]
            elif name.startswith("up_blocks"):
                # Processor names look like "up_blocks.<id>.attentions...". Use
                # split so multi-digit ids also parse.
                block_id = int(name.split(".")[1])
                hidden_size = list(reversed(block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name.split(".")[1])
                hidden_size = block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = Consistent_AttProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
                ).to(self.device, dtype=self.torch_dtype)
            else:
                attn_procs[name] = Consistent_IPAttProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
                ).to(self.device, dtype=self.torch_dtype)
        unet.set_attn_processor(attn_procs)

    def get_facial_embeds(self, prompt_embeds, negative_prompt_embeds, facial_clip_images, facial_token_masks, valid_facial_token_idx_mask):
        """Merge per-facial-part CLIP features into the text embeddings.

        Each element of ``facial_clip_images`` (iterated over dim 0) is encoded
        with the CLIP vision encoder. The penultimate hidden state is used. The
        unconditional branch encodes all-zero images of the same shape. Both
        results are stacked and passed to ``FacialEncoder`` with the token
        masks.

        Returns:
            ``(facial_prompt_embeds, uncond_facial_prompt_embeds)``.
        """
        hidden_states = []
        uncond_hidden_states = []
        for facial_clip_image in facial_clip_images:
            hidden_state = self.image_encoder(facial_clip_image.to(self.device, dtype=self.torch_dtype), output_hidden_states=True).hidden_states[-2]
            uncond_hidden_state = self.image_encoder(torch.zeros_like(facial_clip_image, dtype=self.torch_dtype).to(self.device), output_hidden_states=True).hidden_states[-2]
            hidden_states.append(hidden_state)
            uncond_hidden_states.append(uncond_hidden_state)
        multi_facial_embeds = torch.stack(hidden_states)
        uncond_multi_facial_embeds = torch.stack(uncond_hidden_states)

        facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
        uncond_facial_prompt_embeds = self.FacialEncoder(negative_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
        return facial_prompt_embeds, uncond_facial_prompt_embeds

    def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut=False):
        """Build identity tokens from the face-ID embedding and the whole face image.

        The conditional tokens combine ``faceid_embeds`` with the CLIP
        penultimate hidden state of ``face_image``. The unconditional tokens
        (for classifier-free guidance) use a zero ID embedding and the CLIP
        encoding of an all-zero image.

        Returns:
            ``(image_prompt_tokens, uncond_image_prompt_embeds)``.
        """
        clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]

        faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
        image_prompt_tokens = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
        return image_prompt_tokens, uncond_image_prompt_embeds

    def set_scale(self, scale):
        """Set ``scale`` on every ``Consistent_IPAttProcessor`` in the UNet."""
        # The pipeline itself owns the UNet (there is no `self.pipe`).
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, Consistent_IPAttProcessor):
                attn_processor.scale = scale

    def get_prepare_faceid(self, face_image):
        """Return the InsightFace embedding of the first detected face as a (1, 512) tensor.

        If no face is detected, a (1, 512) zero tensor is returned instead.
        NOTE(review): InsightFace is usually fed BGR (cv2) arrays, but this
        passes the RGB array produced from a PIL image. The behaviour is kept as
        it was; confirm it matches how the adapter was trained.
        """
        faceid_image = np.array(face_image)
        faces = self.app.get(faceid_image)
        if not faces:
            return torch.zeros(1, 512)
        return torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)

    def parsing_face_mask(self, raw_image_refer):
        """Run BiSeNet face parsing on a PIL image resized to 512x512.

        Returns:
            ``(vis_parsing_anno_color, vis_parsing_anno)``:
              * ``vis_parsing_anno_color``: the colourised class map blended 60/40
                over the resized image, in BGR order (uint8).
              * ``vis_parsing_anno``: the 512x512 uint8 map of per-pixel class ids.
        """
        to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        with torch.no_grad():
            # BiSeNet and the 3-channel Normalize expect RGB; convert RGBA/L inputs.
            image = raw_image_refer.convert("RGB").resize((512, 512), Image.BILINEAR)
            img = to_tensor(image).unsqueeze(0).float().cuda()
            out = self.bise_net(img)[0]
            # Per-pixel argmax over the class channel -> (512, 512) class ids.
            parsing_anno = out.squeeze(0).cpu().numpy().argmax(0)

        vis_parsing_anno = parsing_anno.astype(np.uint8)

        # Colour every non-background class with its palette entry; background
        # (class 0) stays white.
        palette = np.asarray(self.part_colors, dtype=np.uint8)
        vis_parsing_anno_color = np.full((*vis_parsing_anno.shape, 3), 255, dtype=np.uint8)
        foreground = vis_parsing_anno > 0
        vis_parsing_anno_color[foreground] = palette[vis_parsing_anno[foreground]]

        vis_im = np.asarray(image, dtype=np.uint8)
        vis_parsing_anno_color = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
        return vis_parsing_anno_color, vis_parsing_anno

    def get_prepare_llva_caption(self, input_image_file, model_path=None, prompt=None):
        """Return a facial-feature caption for the input image.

        LLaVA captioning is currently disabled (the model is never loaded), so
        a fixed template is returned. All arguments are ignored and exist for
        API compatibility.
        """
        face_caption = "The person has one nose, two eyes, two ears, and a mouth."
        return face_caption

    def get_prepare_facemask(self, input_image_file):
        """Parse the face and keep one mask per facial part.

        Only the parts listed below are kept. For left/right (or upper/lower)
        pairs, only the first mask encountered in ``masks_for_unique_values``
        order is kept: one eye, one ear, one lip. The background mask is not
        used.

        Returns:
            ``(key_parsing_mask_list, vis_parsing_anno_color)``: a dict from part
            name to mask, and the visualisation from ``parsing_face_mask``.
        """
        vis_parsing_anno_color, vis_parsing_anno = self.parsing_face_mask(input_image_file)
        parsing_mask_list = masks_for_unique_values(vis_parsing_anno)

        key_list = {"Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"}
        key_parsing_mask_list = {}
        seen_parts = set()
        for key, mask_image in parsing_mask_list.items():
            if key not in key_list:
                continue
            # "Left_Ear" -> part "Ear"; keys without "_" (Face, Nose) are always kept.
            _, sep, part = key.partition("_")
            if sep:
                if part in seen_parts:
                    continue
                seen_parts.add(part)
            key_parsing_mask_list[key] = mask_image
        return key_parsing_mask_list, vis_parsing_anno_color

    def encode_prompt_with_trigger_word(
        self,
        prompt: str,
        face_caption: str,
        key_parsing_mask_list=None,
        image_token="<|image|>",
        facial_token="<|facial|>",
        max_num_facials=5,
        num_id_images: int = 1,
        device: Optional[torch.device] = None,
    ):
        """Combine the user prompt with the facial caption and locate the trigger tokens.

        ``process_text_with_markers`` aligns the caption with the parsed masks.
        Per the original author's notes, it orders the part keywords like the
        masks and inserts ``facial_token`` after each one. If ``prompt +
        "Detail:" + caption`` does not fit in the tokenizer's max length, the
        caption is moved in front so the trigger tokens are not truncated. If
        the raw caption is longer than 330 characters, the caption is dropped.

        Returns:
            ``(prompt_text_only, clean_input_id, key_parsing_mask_list_align,
            facial_token_mask, facial_token_idx, facial_token_idx_mask)``.
            ``prompt_text_only`` is the combined prompt with both trigger
            strings removed. The rest come from
            ``tokenize_and_mask_noun_phrases_ends`` and
            ``prepare_image_token_idx``.
        """
        device = device or self._execution_device  # not used below; kept for API compatibility

        face_caption_align, key_parsing_mask_list_align = process_text_with_markers(face_caption, key_parsing_mask_list)
        prompt_face = prompt + "Detail:" + face_caption_align

        max_text_length = 330
        tokenizer = self.tokenizer
        model_max_length = tokenizer.model_max_length
        # With padding="max_length" and no truncation, the length exceeds
        # model_max_length exactly when the prompt would be cut off.
        num_tokens = len(tokenizer(prompt_face, max_length=model_max_length, padding="max_length", truncation=False, return_tensors="pt").input_ids[0])
        if num_tokens != model_max_length:
            prompt_face = "Detail:" + face_caption_align + " Caption:" + prompt
        if len(face_caption) > max_text_length:
            prompt_face = prompt
        prompt_text_only = prompt_face.replace(facial_token, "").replace(image_token, "")

        facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
        image_token_id = None  # identity trigger tokens are not placed in the prompt
        clean_input_id, image_token_mask, facial_token_mask = tokenize_and_mask_noun_phrases_ends(
            prompt_face, image_token_id, facial_token_id, tokenizer)
        image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = prepare_image_token_idx(
            image_token_mask, facial_token_mask, num_id_images, max_num_facials)
        return prompt_text_only, clean_input_id, key_parsing_mask_list_align, facial_token_mask, facial_token_idx, facial_token_idx_mask

    def get_prepare_clip_image(self, input_image_file, key_parsing_mask_list, image_size=512, max_num_facials=5, change_facial=True):
        """Crop each facial part from the image and preprocess it for CLIP.

        For every mask, ``fetch_mask_raw_image`` extracts the masked region and
        ``CLIPImageProcessor`` preprocesses it. The lists are zero-padded up to
        ``max_num_facials`` entries. ``change_facial`` is unused.

        Returns:
            ``(facial_clip_image, facial_mask)``. Only ``facial_clip_image`` is
            used at inference; ``facial_mask`` is the centre-cropped masks
            tensor.
        """
        transform_mask = transforms.Compose([transforms.CenterCrop(size=image_size), transforms.ToTensor()])
        clip_image_processor = CLIPImageProcessor()
        facial_mask = []
        facial_clip_image = []
        for key in key_parsing_mask_list:
            key_mask = key_parsing_mask_list[key]
            facial_mask.append(transform_mask(key_mask))
            key_mask_raw_image = fetch_mask_raw_image(input_image_file, key_mask)
            facial_clip_image.append(clip_image_processor(images=key_mask_raw_image, return_tensors="pt").pixel_values)

        num_padding = max_num_facials - len(key_parsing_mask_list)
        if num_padding > 0:
            facial_clip_image += [torch.zeros(1, 3, 224, 224) for _ in range(num_padding)]
            facial_mask += [torch.zeros(1, image_size, image_size) for _ in range(num_padding)]

        facial_clip_image = torch.stack(facial_clip_image, dim=1).squeeze(0)
        facial_mask = torch.stack(facial_mask, dim=0).squeeze(dim=1)
        return facial_clip_image, facial_mask

    # Like the diffusers base pipeline, inference must not record autograd
    # history; otherwise every denoising step keeps its graph alive.
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        original_size: Optional[Tuple[int, int]] = None,
        target_size: Optional[Tuple[int, int]] = None,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        input_id_images: PipelineImageInput = None,
        reference_id_images: PipelineImageInput = None,
        start_merge_step: int = 0,
        class_tokens_mask: Optional[torch.LongTensor] = None,
        prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
        retouching: bool = False,
        need_safetycheck: bool = True,
    ):
        """Generate images of the identity in ``input_id_images`` following ``prompt``.

        Only the first image of ``input_id_images`` is used for identity.
        Classifier-free guidance is mandatory (``guidance_scale >= 1.0``). For
        steps ``i <= start_merge_step``, the text-only+ID embeddings condition
        the UNet; after that, the facial-part-augmented embeddings do.

        ``reference_id_images``, ``class_tokens_mask``,
        ``prompt_embeds_text_only``, ``original_size`` and ``target_size`` are
        accepted for API compatibility but are not used. The ``prompt_embeds``
        and ``negative_prompt_embeds`` arguments are only passed to
        ``check_inputs``; the conditioning is always rebuilt from ``prompt``.

        Raises:
            ValueError: If ``input_id_images`` is missing or
                ``guidance_scale < 1.0``.
        """
        # 0. Default height and width to the UNet's native resolution.
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Validate cheaply, before any face processing.
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )
        if input_id_images is None or (isinstance(input_id_images, list) and len(input_id_images) == 0):
            raise ValueError("`input_id_images` must contain at least one face image.")
        if not isinstance(input_id_images, list):
            input_id_images = [input_id_images]
        # The embeddings below are always built as [uncond, cond, text-only],
        # so classifier-free guidance cannot be switched off.
        do_classifier_free_guidance = guidance_scale >= 1.0
        if not do_classifier_free_guidance:
            raise ValueError(f"ConsistentID requires guidance_scale >= 1.0, got {guidance_scale}.")

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = self._execution_device
        input_image_file = input_id_images[0]

        # Identity comes from the first input image only. reference_id_images
        # was used in experiments (averaging its face IDs) that showed no clear
        # benefit, so it is ignored.
        faceid_embeds = self.get_prepare_faceid(face_image=input_image_file)
        face_caption = self.get_prepare_llva_caption(input_image_file)
        key_parsing_mask_list, _ = self.get_prepare_facemask(input_image_file)

        # 3. Encode input prompt with the facial trigger words.
        num_id_images = len(input_id_images)
        (
            prompt_text_only,
            clean_input_id,
            key_parsing_mask_list_align,
            facial_token_mask,
            facial_token_idx,
            facial_token_idx_mask,
        ) = self.encode_prompt_with_trigger_word(
            prompt=prompt,
            face_caption=face_caption,
            key_parsing_mask_list=key_parsing_mask_list,
            device=device,
            max_num_facials=5,
            num_id_images=num_id_images,
        )

        # 4. Encode the prompt without trigger words (for delayed conditioning).
        encoder_hidden_states = self.text_encoder(clean_input_id.to(device))[0]
        prompt_embeds = self._encode_prompt(
            prompt_text_only,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )
        # _encode_prompt returns cat([negative, positive]) along dim 0.
        negative_encoder_hidden_states_text_only = prompt_embeds[0:num_images_per_prompt]
        encoder_hidden_states_text_only = prompt_embeds[num_images_per_prompt:]

        # 5. Identity tokens (face ID + global CLIP) and facial-part CLIP inputs.
        prompt_tokens_faceid, uncond_prompt_tokens_faceid = self.get_image_embeds(faceid_embeds, face_image=input_image_file, s_scale=0.0, shortcut=True)
        facial_clip_image, facial_mask = self.get_prepare_clip_image(input_image_file, key_parsing_mask_list_align, image_size=512, max_num_facials=5)
        facial_clip_images = facial_clip_image.unsqueeze(0).to(device, dtype=self.torch_dtype)
        facial_token_mask = facial_token_mask.to(device)
        facial_token_idx_mask = facial_token_idx_mask.to(device)
        negative_encoder_hidden_states = negative_encoder_hidden_states_text_only
        # NOTE(review): the caller's cross_attention_kwargs is deliberately
        # overridden; the custom processors are always called without extras.
        cross_attention_kwargs = {}

        # 6. Merge facial-part features into the text embeddings, then append
        #    the identity tokens along the token dimension.
        prompt_embeds_facial, uncond_prompt_embeds_facial = self.get_facial_embeds(
            encoder_hidden_states, negative_encoder_hidden_states,
            facial_clip_images, facial_token_mask, facial_token_idx_mask)
        prompt_embeds = torch.cat([prompt_embeds_facial, prompt_tokens_faceid], dim=1)
        negative_prompt_embeds = torch.cat([uncond_prompt_embeds_facial, uncond_prompt_tokens_faceid], dim=1)

        # With embeddings supplied, _encode_prompt only repeats them per image
        # and concatenates [negative, positive].
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
        # Append the third group, text-only + identity tokens. The identity
        # tokens have batch size 1, so repeat them to match
        # num_images_per_prompt.
        faceid_tokens_per_image = prompt_tokens_faceid.repeat(encoder_hidden_states_text_only.shape[0], 1, 1)
        prompt_embeds_text_only = torch.cat([encoder_hidden_states_text_only, faceid_tokens_per_image], dim=1)
        prompt_embeds = torch.cat([prompt_embeds, prompt_embeds_text_only], dim=0)

        # 7. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 8. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        (
            null_prompt_embeds,       # unconditional / negative prompt
            augmented_prompt_embeds,  # facial-augmented text + identity tokens
            text_prompt_embeds,       # plain text + identity tokens
        ) = prompt_embeds.chunk(3)

        # 9. Denoising loop (classifier-free guidance is always on).
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                cond_embeds = text_prompt_embeds if i <= start_merge_step else augmented_prompt_embeds
                current_prompt_embeds = torch.cat([null_prompt_embeds, cond_embeds], dim=0)

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=current_prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        else:
            image = self.decode_latents(latents)
            # need_safetycheck is honoured for every decoded output type.
            if need_safetycheck:
                image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
            else:
                has_nsfw_concept = None
            if output_type == "pil":
                image = self.numpy_to_pil(image)
                # Optional ModelScope skin retouching. It takes a PIL image and
                # returns a BGR array. Only the first image is retouched, and the
                # result replaces the whole list.
                if retouching:
                    after_retouching = self.skin_retouching(image[0])
                    if OutputKeys.OUTPUT_IMG in after_retouching:
                        image = [Image.fromarray(cv2.cvtColor(after_retouching[OutputKeys.OUTPUT_IMG], cv2.COLOR_BGR2RGB))]

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)
        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )