File size: 6,862 Bytes
029d226 43f1a0b 029d226 d838dbe 029d226 43f1a0b 029d226 43f1a0b 029d226 d838dbe 029d226 d838dbe 029d226 d838dbe 029d226 d838dbe 029d226 d838dbe 029d226 d838dbe 029d226 b534628 029d226 d838dbe 029d226 d838dbe 029d226 43f1a0b 029d226 8e9df48 43f1a0b 029d226 8e9df48 029d226 43f1a0b 029d226 d838dbe 43f1a0b d838dbe 029d226 d838dbe 029d226 9c897db 029d226 d838dbe 029d226 d838dbe 029d226 d838dbe 8e9df48 029d226 43f1a0b 029d226 d838dbe 029d226 43f1a0b 029d226 8e9df48 d838dbe 6cd45a2 d838dbe 029d226 d838dbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
import argparse
import sys
import torch
import torch.nn as nn
from PIL import Image
from transformers import (
AutoTokenizer,
BitsAndBytesConfig,
LlamaForCausalLM,
SiglipImageProcessor,
SiglipVisionModel,
)
from transformers import TextStreamer
def tokenizer_image_token(prompt, tokenizer, image_token_index=-200):
    """Tokenize *prompt*, substituting ``image_token_index`` for each "<image>" marker.

    Every text segment between markers is tokenized by a separate tokenizer
    call. For all segments after the first, the leading id is dropped on the
    assumption that it is a BOS id added by the tokenizer, so only the first
    segment keeps its BOS.

    Args:
        prompt: text that may contain one or more "<image>" markers.
        tokenizer: callable returning an object with an ``input_ids`` list.
        image_token_index: id written in place of each marker.

    Returns:
        1-D ``torch.long`` tensor of token ids.
    """
    first, *rest = [tokenizer(segment).input_ids for segment in prompt.split("<image>")]
    ids = list(first)
    for segment_ids in rest:
        ids += [image_token_index, *segment_ids[1:]]
    return torch.tensor(ids, dtype=torch.long)
def process_tensors(input_ids, image_features, embedding_layer, image_token_index=-200):
    """Splice projected image features into a token sequence at the image placeholder.

    The first occurrence of ``image_token_index`` in ``input_ids`` is removed.
    The ids before and after it are embedded with ``embedding_layer``, and
    ``image_features`` is inserted between the two halves along dim 1.

    Args:
        input_ids: integer tensor with a leading batch dimension (indexed as
            ``input_ids[:, ...]``). The split position is taken from the first
            match found, so in practice this expects a batch of one sequence.
        image_features: tensor concatenated on dim 1 between the two text
            embedding halves; it must match them in every other dimension.
        embedding_layer: callable mapping token ids to embeddings (e.g. the
            language model's input embedding module).
        image_token_index: placeholder id to look for; the default matches the
            default of ``tokenizer_image_token``.

    Returns:
        ``(embeddings, attention_mask)``. ``attention_mask`` is all ones, of
        dtype ``torch.long``, covering every position of ``embeddings``. Both
        live on ``image_features.device``.

    Raises:
        ValueError: if ``input_ids`` contains no ``image_token_index``.
    """
    _, columns = (input_ids == image_token_index).nonzero(as_tuple=True)
    if columns.numel() == 0:
        raise ValueError(
            f"input_ids contains no image placeholder token ({image_token_index})"
        )
    split_index = int(columns[0])

    # Embed the text on either side of the placeholder; the placeholder itself
    # is not a real vocabulary id and must never reach the embedding layer.
    device = image_features.device
    embeddings_before = embedding_layer(input_ids[:, :split_index]).to(device)
    embeddings_after = embedding_layer(input_ids[:, split_index + 1 :]).to(device)

    concatenated_embeddings = torch.cat(
        [embeddings_before, image_features, embeddings_after], dim=1
    )
    # Every position (text and image) is attended to.
    attention_mask = torch.ones(
        concatenated_embeddings.shape[:2], dtype=torch.long, device=device
    )
    return concatenated_embeddings, attention_mask
def initialize_models(
    llm_name="unsloth/llama-3-8b-Instruct",
    vision_name="google/siglip-so400m-patch14-384",
    vision_device="cuda",
):
    """Load the tokenizer, 4-bit language model, vision tower and image processor.

    Args:
        llm_name: model id/path for the tokenizer and ``LlamaForCausalLM``.
        vision_name: model id/path for ``SiglipVisionModel`` and
            ``SiglipImageProcessor``.
        vision_device: device the vision model is moved to. The language
            model is placed by ``device_map="auto"`` instead.

    Returns:
        ``(tokenizer, model, vision_model, processor)``.
    """
    # 4-bit weights, with matmuls computed in float16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
    )
    tokenizer = AutoTokenizer.from_pretrained(llm_name, use_fast=True)
    model = LlamaForCausalLM.from_pretrained(
        llm_name,
        torch_dtype=torch.float16,
        device_map="auto",
        quantization_config=bnb_config,
    )
    # Inference only: freeze the language model's base weights.
    for param in model.base_model.parameters():
        param.requires_grad = False

    vision_model = SiglipVisionModel.from_pretrained(
        vision_name, torch_dtype=torch.float16
    )
    processor = SiglipImageProcessor.from_pretrained(vision_name)
    vision_model = vision_model.to(vision_device)
    return tokenizer, model, vision_model, processor
class ProjectionModule(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) from vision width to LLM width.

    The layers live in ``self.model`` (an ``nn.Sequential``), which gives
    state-dict keys of the form ``model.0.weight`` / ``model.2.bias``. Keep
    that attribute name and layer order so saved projector weights still load.
    """

    def __init__(self, mm_hidden_size, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(mm_hidden_size, hidden_size),  # vision -> LLM width
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` along its last dimension from mm_hidden_size to hidden_size."""
        projected = self.model(x)
        return projected
def load_projection_module(
    mm_hidden_size=1152,
    hidden_size=4096,
    device="cuda",
    checkpoint_path="./mm_projector.bin",
):
    """Build a ``ProjectionModule`` and load its weights from ``checkpoint_path``.

    Args:
        mm_hidden_size: input width of the projector.
        hidden_size: output width of the projector.
        device: device the loaded module is moved to.
        checkpoint_path: path of the saved projector state dict.

    Returns:
        The module on ``device``, converted to float16 via ``.half()``.
    """
    projection_module = ProjectionModule(mm_hidden_size, hidden_size)
    # Load to CPU first, because the module is moved to ``device`` below anyway.
    # This still works on machines lacking the GPU the checkpoint was saved from.
    # weights_only=True refuses arbitrary pickled objects; the checkpoint is
    # presumably a plain tensor state dict, so nothing more is needed.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    # Drop the "mm_projector." segment from every key so the names line up
    # with ProjectionModule's own parameter names.
    checkpoint = {k.replace("mm_projector.", ""): v for k, v in checkpoint.items()}
    projection_module.load_state_dict(checkpoint)
    return projection_module.to(device).half()
def _read_user_turn():
    """Prompt for one user message; return "" on EOF (e.g. closed stdin)."""
    try:
        return input("\nuser: ")
    except EOFError:
        return ""


def _embed_image(image, processor, vision_model, projection_module):
    """Encode *image* with the vision tower and project it into the LLM embedding space."""
    image_inputs = processor(
        images=[image],
        return_tensors="pt",
        do_resize=True,
        size={"height": 384, "width": 384},
    ).to("cuda")
    pixel_values = image_inputs["pixel_values"].squeeze(0)
    image_forward_outs = vision_model(
        pixel_values.to(device="cuda", dtype=torch.float16).unsqueeze(0),
        output_hidden_states=True,
    )
    # Use the second-to-last hidden layer rather than the final output.
    image_features = image_forward_outs.hidden_states[-2]
    return projection_module(image_features).to("cuda")


def answer_question(
    image_path, tokenizer, model, vision_model, processor, projection_module
):
    """Run an interactive, multi-turn chat about the image at *image_path*.

    The first user message is prefixed with the "<image>" placeholder. Later
    turns are appended to the running embedding sequence. Replies are streamed
    to stdout. An empty message or EOF prints a notice and calls ``sys.exit()``.
    """
    image = Image.open(image_path).convert("RGB")
    # NOTE(review): presumably meant to stop the tokenizer adding BOS; whether
    # this has any effect depends on the tokenizer implementation — confirm.
    tokenizer.bos_token_id = None
    # Llama-3 chat turns end with <|eot_id|>. Its id is also used as pad id below.
    tokenizer.eos_token = "<|eot_id|>"

    q = _read_user_turn()
    if not q:
        print("no input detected. exiting.")
        sys.exit()
    question = "<image>" + q
    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    input_ids = (
        tokenizer_image_token(prompt, tokenizer)
        .unsqueeze(0)
        .to(model.device)
    )
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        projected_embeddings = _embed_image(
            image, processor, vision_model, projection_module
        )
        embedding_layer = model.get_input_embeddings()
        new_embeds, attn_mask = process_tensors(
            input_ids, projected_embeddings, embedding_layer
        )
        device = model.device
        attn_mask = attn_mask.to(device)
        new_embeds = new_embeds.to(device)
        model_kwargs = {
            "do_sample": True,
            "temperature": 0.2,
            "max_new_tokens": 2000,
            "use_cache": True,
            "streamer": streamer,
            "pad_token_id": tokenizer.eos_token_id,
        }
        while True:
            print('assistant: ')
            # NOTE(review): assumes generate() returns only the newly generated
            # ids when given inputs_embeds without input_ids — confirm for the
            # installed transformers version.
            generated_ids = model.generate(
                inputs_embeds=new_embeds, attention_mask=attn_mask, **model_kwargs
            )[0]

            q = _read_user_turn()
            if not q:
                print("no input detected. exiting.")
                sys.exit()

            # Close the user turn with <|eot_id|>, exactly like the first prompt.
            turn_text = (
                "<|start_header_id|>user<|end_header_id|>\n\n"
                + q
                + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            )
            # add_special_tokens=False: this continues an existing sequence, so
            # no BOS may be inserted mid-conversation.
            turn_ids = tokenizer(
                turn_text, return_tensors="pt", add_special_tokens=False
            ).input_ids.to(device)
            # Embed the reply from the ids the model actually produced, instead
            # of a decode/re-encode round trip that may tokenize differently.
            reply_ids = generated_ids.unsqueeze(0).to(device)
            new_embeds = torch.cat(
                [new_embeds, embedding_layer(reply_ids), embedding_layer(turn_ids)],
                dim=1,
            )
            # Same dtype as the initial mask from process_tensors.
            attn_mask = torch.ones(
                new_embeds.shape[:2], dtype=torch.long, device=device
            )
if __name__ == "__main__":
    # Parse the CLI before loading any model, so argument errors and --help
    # return immediately.
    arg_parser = argparse.ArgumentParser(description="Answer questions based on an image")
    arg_parser.add_argument("-i", "--image", required=True, help="Path to the image file")
    cli_args = arg_parser.parse_args()

    tokenizer, model, vision_model, processor = initialize_models()
    projection_module = load_projection_module()
    answer_question(
        image_path=cli_args.image,
        tokenizer=tokenizer,
        model=model,
        vision_model=vision_model,
        processor=processor,
        projection_module=projection_module,
    )
|