import torch
import torch.nn as nn
from minigpt4.models.mini_gpt4 import MiniGPT4
from minigpt4.models.blip2 import Blip2Base, disabled_train
from transformers.models.gpt_neox import GPTNeoXForCausalLM
from transformers import AutoTokenizer


class CustomizedGPTNeoXForCausalLM(GPTNeoXForCausalLM):
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        input_shape = input_ids.shape

        # cut decoder_input_ids if past is used
        if past_key_values and past_key_values[0] is not None:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if the model is used as a decoder in an encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "past_key_values": past_key_values,
            }
        )
        return model_inputs


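# Illustrative note (not part of the original file): the override above lets
# `generate()` start from `inputs_embeds` alone on the first step and fall
# back to `input_ids` once past key/values exist, which is what feeding soft
# visual prompts into the LLM requires. With a hypothetical loaded `model`
# and prompt embeddings `mixed_embs` of shape (1, seq_len, hidden):
#
#     output_ids = model.generate(
#         inputs_embeds=mixed_embs,
#         max_new_tokens=128,
#     )

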
class CustomizedMiniGPT4(Blip2Base):
    """
    BLIP-2 with a GPT-NeoX language model.
    """

    def __init__(
        self,
        gpt_neox_model="rinna/bilingual-gpt-neox-4b",
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        freeze_qformer=True,
        num_query_token=32,
        low_resource=False,  # load the LLM in 8-bit and keep the ViT on CPU
        device_8bit=0,  # the device of the 8-bit model must be set at load time and cannot be changed afterwards
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.low_resource = low_resource

        print('Loading VIT', flush=True)
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        if freeze_vit:
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            for name, param in self.ln_vision.named_parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
            self.ln_vision.train = disabled_train
            print("freeze vision encoder")
        print('Loading VIT Done')
        print('Loading Q-Former', flush=True)
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        # Only the query branch of the Q-Former is used, so drop the text
        # branch (cls head, text embeddings, per-layer output/intermediate).
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)

        if freeze_qformer:
            for name, param in self.Qformer.named_parameters():
                param.requires_grad = False
            self.Qformer = self.Qformer.eval()
            self.Qformer.train = disabled_train
            self.query_tokens.requires_grad = False
            print("freeze Qformer")
        print('Loading Q-Former Done')
        print('Loading LLM', flush=True)
        self.gpt_neox_tokenizer = AutoTokenizer.from_pretrained(gpt_neox_model, use_fast=False)
        if self.low_resource:
            self.gpt_neox_model = CustomizedGPTNeoXForCausalLM.from_pretrained(
                gpt_neox_model,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map={'': device_8bit},
            )
        else:
            self.gpt_neox_model = CustomizedGPTNeoXForCausalLM.from_pretrained(
                gpt_neox_model,
                torch_dtype=torch.float16,
            )
        for name, param in self.gpt_neox_model.named_parameters():
            param.requires_grad = False
        print('Loading LLM Done')

        # The attribute keeps MiniGPT4's name `llama_proj` so that
        # MiniGPT4.encode_img, which is reused below, can find it.
        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.gpt_neox_model.config.hidden_size
        )
    def vit_to_cpu(self):
        # Reuse MiniGPT4's implementation; it only touches attributes defined above.
        MiniGPT4.vit_to_cpu(self)

    def encode_img(self, image):
        # MiniGPT4.encode_img returns (projected embeddings, attention mask);
        # only the embeddings are needed here.
        inputs_gpt_neox, _ = MiniGPT4.encode_img(self, image)
        return inputs_gpt_neox

    def get_context_emb(self, prompt, img_list):
        # Interleave text-segment embeddings with image embeddings at each
        # <ImageHere> placeholder.
        prompt_segs = prompt.split('<ImageHere>')
        assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
        seg_tokens = [
            self.gpt_neox_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(self.device).input_ids
            for seg in prompt_segs
        ]
        seg_embs = [self.gpt_neox_model.gpt_neox.embed_in(seg_t) for seg_t in seg_tokens]
        mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
        mixed_embs = torch.cat(mixed_embs, dim=1)
        return mixed_embs
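

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a CUDA
    # device and that the BLIP-2/MiniGPT-4 checkpoints referenced above can be
    # downloaded; `image` is a stand-in for a real preprocessed image tensor,
    # and the prompt format and generation settings are placeholders.
    model = CustomizedMiniGPT4().to("cuda").eval()

    image = torch.randn(1, 3, 224, 224, device="cuda")
    img_emb = model.encode_img(image)

    # One <ImageHere> placeholder per entry in the image list.
    prompt = "User: <ImageHere>\nDescribe this image.\nSystem: "
    embs = model.get_context_emb(prompt, [img_emb])

    output_ids = model.gpt_neox_model.generate(
        inputs_embeds=embs,
        max_new_tokens=64,
        do_sample=False,
    )
    print(model.gpt_neox_tokenizer.decode(output_ids[0], skip_special_tokens=True))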