jiuface committed on
Commit
d067625
1 Parent(s): 8d7d2d7
Files changed (1)
  1. app.py +14 -5
app.py CHANGED
@@ -365,31 +365,40 @@ def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_s
     with calculateDuration("Set random seed"):
         seed = random.randint(0, MAX_SEED)

-    # Encode the prompts
     with calculateDuration("Encoding prompts"):
         # Encode the background prompt
+        # Use tokenizer_2 with text_encoder_2
+        bg_text_input_2 = pipe.tokenizer_2(prompt_bg, return_tensors="pt").to(device)
+        bg_prompt_embeds = pipe.text_encoder_2(bg_text_input_2.input_ids.to(device))[0]
+
+        # Use tokenizer with text_encoder
         bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to(device)
-        bg_prompt_embeds = pipe.text_encoder_2(bg_text_input.input_ids.to(device))[0]
         bg_pooled_embeds = pipe.text_encoder(bg_text_input.input_ids.to(device)).pooler_output

         # Encode the character prompts
         character_prompt_embeds = []
         character_pooled_embeds = []
         for prompt in character_prompts:
+            # Use tokenizer_2 with text_encoder_2
+            char_text_input_2 = pipe.tokenizer_2(prompt, return_tensors="pt").to(device)
+            char_prompt_embeds = pipe.text_encoder_2(char_text_input_2.input_ids.to(device))[0]
+            # Use tokenizer with text_encoder
             char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to(device)
-            char_prompt_embeds = pipe.text_encoder_2(char_text_input.input_ids.to(device))[0]
             char_pooled_embeds = pipe.text_encoder(char_text_input.input_ids.to(device)).pooler_output
+
             character_prompt_embeds.append(char_prompt_embeds)
             character_pooled_embeds.append(char_pooled_embeds)

         # Encode the interaction-details prompt
+        details_text_input_2 = pipe.tokenizer_2(prompt_details, return_tensors="pt").to(device)
+        details_prompt_embeds = pipe.text_encoder_2(details_text_input_2.input_ids.to(device))[0]
+
         details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to(device)
-        details_prompt_embeds = pipe.text_encoder_2(details_text_input.input_ids.to(device))[0]
         details_pooled_embeds = pipe.text_encoder(details_text_input.input_ids.to(device)).pooler_output

         # Merge the background and interaction-detail embeddings
         prompt_embeds = torch.cat([bg_prompt_embeds, details_prompt_embeds], dim=1)
-        pooled_prompt_embeds = torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=1)
+        pooled_prompt_embeds = torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=-1)

         # Parse character positions
         character_infos = []
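
For context, the commit pairs each text encoder with its matching tokenizer: sequence-level embeddings come from tokenizer_2 + text_encoder_2, pooled embeddings from tokenizer + text_encoder. A minimal sketch of that pattern, assuming a diffusers-style pipeline that exposes tokenizer/text_encoder and tokenizer_2/text_encoder_2 (the helper name encode_prompt_pair is illustrative and not part of app.py):

import torch

def encode_prompt_pair(pipe, prompt, device):
    # Sequence embeddings: tokenizer_2 must feed text_encoder_2, since each
    # encoder only understands token IDs from its own tokenizer's vocabulary.
    ids_2 = pipe.tokenizer_2(prompt, return_tensors="pt").input_ids.to(device)
    prompt_embeds = pipe.text_encoder_2(ids_2)[0]  # sequence output

    # Pooled embeddings: tokenizer pairs with text_encoder.
    ids = pipe.tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    pooled_embeds = pipe.text_encoder(ids).pooler_output  # (batch, hidden)

    return prompt_embeds, pooled_embeds

If text_encoder_2 returns hidden states of shape (batch, seq_len, hidden), the dim=1 concatenation joins the two prompts along the token axis; the pooler_output tensors are 2-D (batch, hidden), where dim=1 and dim=-1 refer to the same feature axis, so the switch to dim=-1 mainly makes that intent explicit.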