qtnx commited on
Commit
d5eed63
1 Parent(s): 8abd220

fix attention mask warning

Browse files

I literally cannot test this code, but normally it should work.

Files changed (1) hide show
  1. modeling_llamavision.py +10 -1
modeling_llamavision.py CHANGED
@@ -105,8 +105,17 @@ class Llamavision(PreTrainedModel):
 
         with torch.no_grad():
             inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
+
+            attention_mask = torch.ones(
+                inputs_embeds.shape[:2],
+                dtype=torch.long,
+                device=inputs_embeds.device
+            )
+
             output_ids = self.text_model.generate(
-                inputs_embeds=inputs_embeds, **generate_config
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                **generate_config
             )
 
         return tokenizer.batch_decode(output_ids, skip_special_tokens=True)