llama-3.1-8B-vision-378 / modeling_llamavision.py
qtnx's picture
fix attention mask warning
d5eed63 verified
import torch
import torch.nn as nn
from transformers import (
PreTrainedModel,
AutoModelForCausalLM,
AutoModel,
SiglipImageProcessor,
)
from .configuration_llamavision import LlamavisionConfig
class ProjectionModule(nn.Module):
def __init__(self, mm_hidden_size=1152, hidden_size=4096):
super(ProjectionModule, self).__init__()
# Directly set up the sequential model
self.model = nn.Sequential(
nn.Linear(mm_hidden_size, hidden_size),
nn.GELU(),
nn.Linear(hidden_size, hidden_size),
)
def forward(self, x):
return self.model(x)
class Llamavision(PreTrainedModel):
config_class = LlamavisionConfig
def __init__(self, config):
super().__init__(config)
self.vision_model = AutoModel.from_config(self.config.vision_config)
self.text_model = AutoModelForCausalLM.from_config(self.config.text_config)
self.processor = SiglipImageProcessor()
self.mm_projector = ProjectionModule(
mm_hidden_size=config.vision_config.hidden_size,
hidden_size=config.text_config.hidden_size,
)
@property
def device(self):
return self.text_model.device
def encode_image(self, image):
image = image.convert("RGB")
image = self.processor(
images=image,
return_tensors="pt",
do_resize=True,
size={"height": 378, "width": 378},
)["pixel_values"].to(
device=self.vision_model.device, dtype=self.vision_model.dtype
)
with torch.no_grad():
return self.vision_model(image, output_hidden_states=True).hidden_states[-2]
def input_embeds(self, prompt, image_embeds, tokenizer):
def _tokenize(txt):
return tokenizer(
txt, return_tensors="pt", add_special_tokens=False
).input_ids.to(self.device)
text_emb = self.text_model.get_input_embeddings()
embeds = []
tokenized_prompt = _tokenize(prompt)
if (
tokenizer.bos_token_id is not None
and tokenized_prompt[0][0] != tokenizer.bos_token_id
):
embeds.append(
text_emb(torch.tensor([[tokenizer.bos_token_id]], device=self.device))
)
projected_image_embeds = self.mm_projector(image_embeds.to(self.device))
embeds.append(projected_image_embeds)
embeds.append(text_emb(tokenized_prompt))
return torch.cat(embeds, dim=1)
def get_input_embeddings(self):
return self.text_model.get_input_embeddings()
def generate(
self,
image_embeds,
prompt,
tokenizer,
max_new_tokens=128,
**kwargs,
):
generate_config = {
"eos_token_id": [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>"),
],
"bos_token_id": tokenizer.bos_token_id,
"pad_token_id": tokenizer.pad_token_id,
"max_new_tokens": max_new_tokens,
**kwargs,
}
with torch.no_grad():
inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
attention_mask = torch.ones(
inputs_embeds.shape[:2],
dtype=torch.long,
device=inputs_embeds.device
)
output_ids = self.text_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
**generate_config
)
return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
def answer_question(self, image, question, tokenizer, **kwargs):
image_embeds = self.encode_image(image)
chat = [
{
"role": "system",
"content": "You are a helpful AI assistant that can see images and answer questions about them.",
},
{"role": "user", "content": question},
]
prompt = tokenizer.apply_chat_template(
chat, tokenize=False, add_generation_prompt=True
)
# Generate the answer
with torch.no_grad():
output = self.generate(
image_embeds=image_embeds,
prompt=prompt,
tokenizer=tokenizer,
**kwargs,
)[0]
# Clean and return the answer
cleaned_answer = output.strip()
return cleaned_answer