chendl committed
Commit e511b8d
1 Parent(s): 9c4c4b2

update chat

multimodal/open_flamingo/chat/conversation.py CHANGED
@@ -262,7 +262,6 @@ def preprocess_conv(data):
 
 class Chat:
     def __init__(self, model, vis_processor, tokenizer, vis_embed_size ):
-        self.device = device
         self.model = model
         self.vis_processor = vis_processor
         self.tokenizer = tokenizer
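
The deleted assignment referenced a device name that is not among the parameters of __init__, so unless a module-level device happened to exist it would raise a NameError as soon as a Chat object was constructed. Where a device handle is still needed, a common PyTorch idiom is to infer it from the model itself; a minimal sketch, assuming the model has parameters and lives on a single device (model_device is illustrative, not part of this repository):

    import torch

    def model_device(model: torch.nn.Module) -> torch.device:
        # Assumption: the model has at least one parameter and is not
        # sharded across multiple devices.
        return next(model.parameters()).device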
@@ -400,20 +399,20 @@ class Chat:
         # self.conv.append_message(self.conv.roles[1], msg)
         return msg
 
-    def get_context_emb(self, conv, img_list):
-        prompt = conv.get_prompt()
-        prompt_segs = prompt.split('<ImageHere>')
-        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
-        seg_tokens = [
-            self.model.llama_tokenizer(
-                seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
-            # only add bos to the first seg
-            for i, seg in enumerate(prompt_segs)
-        ]
-        seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
-        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
-        mixed_embs = torch.cat(mixed_embs, dim=1)
-        return mixed_embs
+    # def get_context_emb(self, conv, img_list):
+    #     prompt = conv.get_prompt()
+    #     prompt_segs = prompt.split('<ImageHere>')
+    #     assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
+    #     seg_tokens = [
+    #         self.model.llama_tokenizer(
+    #             seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
+    #         # only add bos to the first seg
+    #         for i, seg in enumerate(prompt_segs)
+    #     ]
+    #     seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
+    #     mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
+    #     mixed_embs = torch.cat(mixed_embs, dim=1)
+    #     return mixed_embs
 
 def evaluate_exp(
     model,
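
For reference, the now commented-out get_context_emb assembled the language-model input by splitting the prompt on <ImageHere> placeholders, embedding each text segment, and splicing one image embedding into each gap; it also depended on the removed self.device attribute, which is consistent with it being disabled in this commit. A minimal standalone sketch of the interleaving step, assuming precomputed embeddings (interleave_embeddings and the shapes are illustrative, not part of the repository):

    import torch

    def interleave_embeddings(seg_embs, img_embs):
        # seg_embs: text-segment embeddings, each shaped (1, seq_len_i, hidden);
        # img_embs: image embeddings with matching batch and hidden dims.
        # There is exactly one more text segment than there are images, so the
        # result alternates text, image, text, ..., image, text.
        assert len(seg_embs) == len(img_embs) + 1, "unmatched placeholders and images"
        mixed = [emb for pair in zip(seg_embs[:-1], img_embs) for emb in pair]
        mixed.append(seg_embs[-1])
        return torch.cat(mixed, dim=1)  # concatenate along the sequence axis

The zip over seg_embs[:-1] and the image list is what guarantees that each image embedding replaces exactly one placeholder; the final text segment is appended after the loop, mirroring the original mixed_embs construction.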
 