import torch
import torch.nn as nn

from minigpt4.models.mini_gpt4 import MiniGPT4
from minigpt4.models.blip2 import Blip2Base, disabled_train

from transformers.models.gpt_neox import GPTNeoXForCausalLM
from transformers import AutoTokenizer


class CustomizedGPTNeoXForCausalLM(GPTNeoXForCausalLM):
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        input_shape = input_ids.shape

        # With cached key/values, only the last token needs a forward pass.
        if past_key_values and past_key_values[0] is not None:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Derive position_ids from the attention mask so left-padded batches
            # get correct positions; padded slots are filled with a dummy value.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # On the first step, feed the soft prompt via inputs_embeds; once the KV
        # cache exists, switch back to token ids. This is what lets generate()
        # start from a multimodal embedding prefix instead of token ids.
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "past_key_values": past_key_values,
            }
        )
        return model_inputs

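# Illustrative sketch (not part of the original file): the override above lets
# `generate()` be seeded directly with an embedding prefix instead of token ids.
# The function name and generation arguments below are assumptions for demonstration.
def _demo_generate_from_embeds(model: CustomizedGPTNeoXForCausalLM, prefix_embs: torch.Tensor) -> torch.Tensor:
    # prefix_embs: (batch, seq_len, hidden_size), e.g. built by get_context_emb below.
    # On the first decoding step, prepare_inputs_for_generation forwards prefix_embs;
    # subsequent steps use the newly generated token ids plus the KV cache.
    return model.generate(
        inputs_embeds=prefix_embs,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
    )
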
class CustomizedMiniGPT4(Blip2Base):
    """
    MiniGPT-4-style BLIP-2 model that uses GPT-NeoX as the language model
    in place of LLaMA/Vicuna.
    """

    def __init__(
        self,
        gpt_neox_model="rinna/bilingual-gpt-neox-4b",
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        freeze_qformer=True,
        num_query_token=32,
        low_resource=False,
        device_8bit=0,
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.low_resource = low_resource

        print('Loading VIT', flush=True)
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        if freeze_vit:
            # Freeze the vision encoder and its layer norm; disabled_train keeps
            # them in eval mode even if .train() is called on the parent module.
            for param in self.visual_encoder.parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            for param in self.ln_vision.parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
            self.ln_vision.train = disabled_train
            print("freeze vision encoder")
        print('Loading VIT Done')

        print('Loading Q-Former', flush=True)
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        # Strip the parts of the BERT-based Q-Former that are unused here:
        # the text-side embeddings and the per-layer feed-forward heads.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)

        if freeze_qformer:
            for param in self.Qformer.parameters():
                param.requires_grad = False
            self.Qformer = self.Qformer.eval()
            self.Qformer.train = disabled_train
            self.query_tokens.requires_grad = False
            print("freeze Qformer")
        print('Loading Q-Former Done')

        print('Loading LLM', flush=True)
        self.gpt_neox_tokenizer = AutoTokenizer.from_pretrained(gpt_neox_model, use_fast=False)

        if self.low_resource:
            # Load the LLM in 8-bit so it fits on a single memory-constrained GPU.
            self.gpt_neox_model = CustomizedGPTNeoXForCausalLM.from_pretrained(
                gpt_neox_model,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={'': device_8bit},
            )
        else:
            self.gpt_neox_model = CustomizedGPTNeoXForCausalLM.from_pretrained(
                gpt_neox_model,
                torch_dtype=torch.float16,
            )

        # The language model is always frozen; only the projection layer trains.
        for param in self.gpt_neox_model.parameters():
            param.requires_grad = False
        print('Loading LLM Done')

        # Projects Q-Former outputs into the GPT-NeoX embedding space. The attribute
        # keeps the name `llama_proj` so that MiniGPT4.encode_img, which is reused
        # below, finds it on this class.
        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.gpt_neox_model.config.hidden_size
        )

    def vit_to_cpu(self):
        # Reuse MiniGPT4's implementation to offload the ViT in low-resource mode.
        MiniGPT4.vit_to_cpu(self)

    def encode_img(self, image):
        # Reuse MiniGPT4's image encoding (ViT -> Q-Former -> llama_proj) and
        # keep only the projected embeddings, dropping the attention mask.
        inputs_gpt_neox, _ = MiniGPT4.encode_img(self, image)
        return inputs_gpt_neox

    def get_context_emb(self, prompt, img_list):
        # Split the prompt at image placeholders, embed the text segments, and
        # interleave them with the projected image embeddings from encode_img.
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.gpt_neox_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(self.device).input_ids
            for seg in prompt_segs
        ]
        seg_embs = [self.gpt_neox_model.gpt_neox.embed_in(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
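

# Illustrative end-to-end sketch (not part of the original file). The prompt format
# and image preprocessing below are assumptions for demonstration; actual usage
# depends on the surrounding repo's conversation code.
#
#   model = CustomizedMiniGPT4(low_resource=True)
#   image = ...  # a (1, 3, 224, 224) tensor from the repo's image preprocessor
#   img_emb = model.encode_img(image)
#   prompt = "User: <ImageHere> Describe this image.\nAssistant:"
#   embs = model.get_context_emb(prompt, [img_emb])
#   output_ids = model.gpt_neox_model.generate(inputs_embeds=embs, max_new_tokens=64)
#   print(model.gpt_neox_tokenizer.decode(output_ids[0], skip_special_tokens=True))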