yibolu commited on
Commit
3308ae3
1 Parent(s): 6eca12e

update ipadapter

Browse files
lyrasd_model/module/lyrasd_ip_adapter.py CHANGED
@@ -45,17 +45,11 @@ class LyraIPAdapter:
45
  image_encoder_path=None,
46
  num_ip_tokens=4,
47
  ip_projection_dim=None,
48
- fp_ckpt=None,
49
- num_fp_tokens=1,
50
- fp_projection_dim=None,
51
  ):
52
  self.pipe = sd_pipe
53
  self.device = device
54
- self.fp_ckpt = fp_ckpt
55
  self.ip_ckpt = ip_ckpt
56
- self.num_fp_tokens = num_fp_tokens
57
  self.num_ip_tokens = num_ip_tokens
58
- self.fp_projection_dim = fp_projection_dim
59
  self.ip_projection_dim = ip_projection_dim
60
  self.sdxl = sdxl
61
  self.ip_plus = ip_plus
@@ -76,10 +70,6 @@ class LyraIPAdapter:
76
  else:
77
  self.image_proj_model = self.init_proj(self.ip_projection_dim, self.num_ip_tokens)
78
 
79
- # face proj model
80
- if self.fp_ckpt:
81
- self.face_proj_model = self.init_proj(self.fp_projection_dim, self.num_fp_tokens)
82
-
83
  self.load_ip_adapter()
84
 
85
  def init_proj_diffuser(self, state_dict):
@@ -131,16 +121,9 @@ class LyraIPAdapter:
131
  pretrained_path, subfolder, weight_name = parse_ckpt_path(self.ip_ckpt)
132
  dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
133
  unet.load_ip_adapter(dir_ipadapter, "", 1, "fp16")
134
-
135
- if self.fp_ckpt:
136
- state_dict = torch.load(self.fp_ckpt, map_location="cpu")
137
- self.face_proj_model.load_state_dict(state_dict["face_proj"])
138
- pretrained_path, subfolder, weight_name = parse_ckpt_path(self.fp_ckpt)
139
- dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
140
- unet.load_facein(dir_ipadapter, "fp16")
141
 
142
  @torch.inference_mode()
143
- def get_image_embeds(self, image=None, face_emb=None):
144
  image_prompt_embeds, uncond_image_prompt_embeds = None, None
145
 
146
  if image is not None:
@@ -160,22 +143,11 @@ class LyraIPAdapter:
160
  uncond_clip_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
161
  image_prompt_embeds = clip_image_prompt_embeds
162
  uncond_image_prompt_embeds = uncond_clip_image_prompt_embeds
163
-
164
- if face_emb is not None:
165
- face_embeds = face_emb.to(self.device, dtype=torch.float16)
166
- face_prompt_embeds = self.face_proj_model(face_embeds)
167
- uncond_face_prompt_embeds = self.face_proj_model(torch.zeros_like(face_embeds))
168
- if image_prompt_embeds is None:
169
- image_prompt_embeds = face_prompt_embeds
170
- uncond_image_prompt_embeds = uncond_face_prompt_embeds
171
- else:
172
- image_prompt_embeds = torch.cat([face_prompt_embeds, image_prompt_embeds], axis=1)
173
- uncond_image_prompt_embeds = torch.cat([uncond_face_prompt_embeds, uncond_image_prompt_embeds], dim=1)
174
 
175
  return image_prompt_embeds, uncond_image_prompt_embeds
176
 
177
  @torch.inference_mode()
178
- def get_image_embeds_lyrasd(self, image=None, ip_image_embeds=None, face_emb=None, batch_size = 1, ip_scale=1.0, fp_scale=1.0, do_classifier_free_guidance=True):
179
  dict_tensor = {}
180
 
181
  if self.ip_ckpt and ip_scale>0:
@@ -199,91 +171,4 @@ class LyraIPAdapter:
199
  clip_image_embeds = torch.cat([uncond_clip_image_embeds, clip_image_embeds])
200
  ip_image_embeds = self.image_proj_model(clip_image_embeds)
201
  dict_tensor["ip_hidden_states"] = ip_image_embeds
202
-
203
- if face_emb is not None and self.fp_ckpt and ip_scale>0:
204
- face_embeds = face_emb.to(self.device, dtype=torch.float16)
205
- face_prompt_embeds = self.face_proj_model(face_embeds)
206
- uncond_face_prompt_embeds = self.face_proj_model(torch.zeros_like(face_embeds))
207
- if do_classifier_free_guidance:
208
- fp_image_embeds = torch.cat([uncond_face_prompt_embeds, face_prompt_embeds])
209
- else:
210
- fp_image_embeds = face_prompt_embeds
211
- dict_tensor["fp_hidden_states"] = fp_image_embeds
212
  return dict_tensor
213
-
214
-
215
- if __name__ == "__main__":
216
- sys.path.append("/data/home/kiokaxiao/repos/LyraSD/python/lyrasd")
217
- from lyrasd_model import LyraSdXLTxt2ImgPipeline
218
-
219
- model_path = "/data/SharedModels/SD/checkpoints/stable-diffusion-xl-base-1.0/"
220
- # model_path = "/cfs-datasets/projects/VirtualIdol/models/base_model/sdxl/xxmix9realisticsdxlV1"
221
- lib_path = os.environ.get("LIBLYRASD_SO")
222
-
223
- dir_ip_adapter = "/cfs-datasets/projects/VirtualIdol/models/ip_adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin"
224
- dir_facein = "/cfs-datasets/projects/VirtualIdol/models/FaceIn/v1/FaceIn_sdxl.bin"
225
- image_encoder_path = "/cfs-datasets/projects/VirtualIdol/models/ip_adapter/models/image_encoder"
226
-
227
- pipeline = LyraSdXLTxt2ImgPipeline(model_path, lib_path)
228
- pipeline.load_ip_adapter(dir_ip_adapter, True, image_encoder_path, 16,1024, dir_facein, 1, 512)
229
- # pipeline.load_ip_adapter(dir_ip_adapter, True, image_encoder_path, 16,1024, "", 1, 512)
230
-
231
- face_emb = np.load("/data/home/kiokaxiao/repos/VidolImageDraw/girl.npy")
232
- face_emb = torch.Tensor(face_emb.reshape([1,-1]))
233
- ip_image = Image.open("/data/home/kiokaxiao/repos/VidolImageDraw/images/input_image.png").convert('RGB')
234
-
235
- generator = torch.Generator("cuda").manual_seed(123)
236
- batches = [2]
237
- sizes = [[512, 512], [768, 768], [1024, 1024]]
238
- # sizes = [[832, 640]]
239
- # sizes = [[1024, 1024]]
240
- running_cnt = 1
241
- do_bench = False
242
-
243
- ip_ratio = 1
244
- facein_ratio = 0.6
245
- extra_tensor_dict = {}
246
- extra_tensor_dict = pipeline.ip_adapter_helper.get_image_embeds_lyrasd(ip_image, None, face_emb, batches[0], ip_ratio, facein_ratio)
247
- param_scale_dict = {"facein_ratio": facein_ratio, "ip_ratio": ip_ratio}
248
- draw_cfg = {'width': 640,
249
- 'num_inference_steps': 30,
250
- 'height': 832,
251
- 'negative_prompt': '(worst quality, low quality, 3d, 2d, cartoons, sketch), tooth, open mouth',
252
- 'guidance_scale': 7,
253
- 'prompt': 'xxmixgirl, masterpiece, best quality, 1girl, solo, looking at viewer, simple background, hair ornament, black eyes, portrait',
254
- 'output_type': 'pil',
255
- 'extra_tensor_dict': extra_tensor_dict,
256
- "param_scale_dict": param_scale_dict}
257
-
258
-
259
- def warmup(draw_cfg):
260
- draw_cfg_wm = deepcopy(draw_cfg)
261
- draw_cfg_wm['num_inference_steps'] = 1
262
- pipeline(**draw_cfg_wm, generator= generator)
263
-
264
- if not do_bench:
265
- images = pipeline(**draw_cfg, generator= generator)
266
- else:
267
- for batch in batches:
268
- for height, width in sizes:
269
- draw_cfg['width'] = width
270
- draw_cfg['height'] = height
271
- draw_cfg['num_images_per_prompt'] = batch
272
- draw_cfg["num_inference_steps"] = 20
273
- warmup(draw_cfg)
274
- time_uses = []
275
- for x in range(running_cnt):
276
- start = time.perf_counter()
277
- draw_cfg['num_images_per_prompt'] = batch
278
- generator = torch.Generator("cuda").manual_seed(123)
279
- print("draw_cfg: ", draw_cfg.keys())
280
- print("draw_cfg: ", draw_cfg)
281
-
282
- images = pipeline(**draw_cfg, generator= generator)
283
- time_use = time.perf_counter() - start
284
- time_uses.append(time_use)
285
- print("bench", batch, width, sum(time_uses)/running_cnt, get_mem_use())
286
-
287
- print(type(images))
288
- images[0].save("t.png")
289
-
 
45
  image_encoder_path=None,
46
  num_ip_tokens=4,
47
  ip_projection_dim=None,
 
 
 
48
  ):
49
  self.pipe = sd_pipe
50
  self.device = device
 
51
  self.ip_ckpt = ip_ckpt
 
52
  self.num_ip_tokens = num_ip_tokens
 
53
  self.ip_projection_dim = ip_projection_dim
54
  self.sdxl = sdxl
55
  self.ip_plus = ip_plus
 
70
  else:
71
  self.image_proj_model = self.init_proj(self.ip_projection_dim, self.num_ip_tokens)
72
 
 
 
 
 
73
  self.load_ip_adapter()
74
 
75
  def init_proj_diffuser(self, state_dict):
 
121
  pretrained_path, subfolder, weight_name = parse_ckpt_path(self.ip_ckpt)
122
  dir_ipadapter = os.path.join(pretrained_path, "lyra_tran", subfolder, '.'.join(weight_name.split(".")[:-1]))
123
  unet.load_ip_adapter(dir_ipadapter, "", 1, "fp16")
 
 
 
 
 
 
 
124
 
125
  @torch.inference_mode()
126
+ def get_image_embeds(self, image=None):
127
  image_prompt_embeds, uncond_image_prompt_embeds = None, None
128
 
129
  if image is not None:
 
143
  uncond_clip_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
144
  image_prompt_embeds = clip_image_prompt_embeds
145
  uncond_image_prompt_embeds = uncond_clip_image_prompt_embeds
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  return image_prompt_embeds, uncond_image_prompt_embeds
148
 
149
  @torch.inference_mode()
150
+ def get_image_embeds_lyrasd(self, image=None, ip_image_embeds=None, batch_size = 1, ip_scale=1.0, do_classifier_free_guidance=True):
151
  dict_tensor = {}
152
 
153
  if self.ip_ckpt and ip_scale>0:
 
171
  clip_image_embeds = torch.cat([uncond_clip_image_embeds, clip_image_embeds])
172
  ip_image_embeds = self.image_proj_model(clip_image_embeds)
173
  dict_tensor["ip_hidden_states"] = ip_image_embeds
 
 
 
 
 
 
 
 
 
 
174
  return dict_tensor