""" | |
Copyright (c) 2023, salesforce.com, inc. | |
All rights reserved. | |
SPDX-License-Identifier: BSD-3-Clause | |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
""" | |
import logging | |
import random | |
import os | |
import torch | |
from torch.cuda.amp import autocast as autocast | |
import torch.nn as nn | |
from minigpt4.common.registry import registry | |
from minigpt4.models.blip2 import Blip2Base, disabled_train | |
from minigpt4.models.modeling_llama import LlamaForCausalLM | |
from transformers import LlamaTokenizer | |

@registry.register_model("mini_gpt4")
class MiniGPT4(Blip2Base):
    """
    BLIP2 GPT-LLAMA model.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_vicuna": "configs/models/minigpt4.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        llama_cache_dir='',
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
    ):
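        """
        Assemble the MiniGPT-4 components: a ViT image encoder, a Q-Former with
        learned query tokens, and an 8-bit Vicuna LLM, plus a linear layer that
        projects Q-Former outputs into the LLaMA embedding space.  With the
        default freeze flags, only that projection layer remains trainable.
        `prompt_path` points to a file of prompt templates containing the
        "<ImageHere>" placeholder; `end_sym` is appended to every training
        caption and `max_txt_len` caps its tokenized length.
        """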
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        if freeze_vit:
            # Freeze the vision encoder and its layer norm, and override
            # .train() so a later .train(True) cannot un-freeze them.
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            for name, param in self.ln_vision.named_parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
            self.ln_vision.train = disabled_train
            logging.info("freeze vision encoder")
        print('Loading VIT Done')
        print('Loading Q-Former')
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        # Drop the text-side components of the Q-Former; only the query-token
        # branch that cross-attends to image features is needed here.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)

        if freeze_qformer:
            for name, param in self.Qformer.named_parameters():
                param.requires_grad = False
            self.Qformer = self.Qformer.eval()
            self.Qformer.train = disabled_train
            self.query_tokens.requires_grad = False
            logging.info("freeze Qformer")
        print('Loading Q-Former Done')
        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(
            'Vision-CAIR/vicuna-7b', use_fast=False, use_auth_token=True)
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

        # NOTE: the `llama_model` and `llama_cache_dir` arguments are accepted
        # for config compatibility but are not used here; the 8-bit Vicuna-7B
        # checkpoint is always loaded from the hub.
        self.llama_model = LlamaForCausalLM.from_pretrained(
            'Vision-CAIR/vicuna-7b', load_in_8bit=True, torch_dtype=torch.float16,
            device_map="auto", use_auth_token=True
        )
        for name, param in self.llama_model.named_parameters():
            param.requires_grad = False
        print('Loading LLAMA Done')
        # The only trainable module: projects Q-Former outputs into the LLaMA
        # embedding space.
        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

        if prompt_path:
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            # Keep only prompts that contain the <ImageHere> placeholder.
            filtered_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
            self.prompt_list = [prompt_template.format(p) for p in filtered_prompts]
            print('Load {} training prompts'.format(len(self.prompt_list)))
            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []

    def vit_to_cpu(self):
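        """Move the vision encoder and its layer norm to CPU and cast them to fp32."""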
self.ln_vision.to("cpu") | |
self.ln_vision.float() | |
self.visual_encoder.to("cpu") | |
self.visual_encoder.float() | |
def encode_img(self, image): | |
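        """
        Encode a batch of images into LLaMA-space embeddings.

        The ViT runs on CPU (see vit_to_cpu), the Q-Former query tokens attend
        to the resulting patch features, and the linear projection maps them
        into the LLaMA embedding space.  Returns (inputs_llama, atts_llama):
        the projected embeddings and an all-ones attention mask of matching shape.
        """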
        device = image.device
        self.vit_to_cpu()
        image = image.to("cpu")

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            inputs_llama = self.llama_proj(query_output.last_hidden_state)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama

    def prompt_wrap(self, img_embeds, atts_img, prompt):
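        """
        Splice the image embeddings into a prompt at the <ImageHere> marker.

        The text before and after the marker is tokenized and embedded with the
        frozen LLaMA embedding table, then concatenated around img_embeds along
        the sequence dimension.  If prompt is empty, the inputs pass through
        unchanged.
        """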
        if prompt:
            batch_size = img_embeds.shape[0]
            p_before, p_after = prompt.split('<ImageHere>')
            p_before_tokens = self.llama_tokenizer(
                p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
            p_after_tokens = self.llama_tokenizer(
                p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
            p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
            p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
            wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
            wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
            return wrapped_img_embeds, wrapped_atts_img
        else:
            return img_embeds, atts_img

    def forward(self, samples):
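        """
        Compute the language-modeling loss for a batch.

        Images are encoded and optionally wrapped in a prompt, the target text
        is tokenized, and both are concatenated (after a BOS embedding) into a
        single embedding sequence for LLaMA.  Labels over the BOS and image
        positions are set to -100 so only the text tokens contribute to the loss.
        """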
        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        if 'question_split' in samples:  # VQA dataset
            print('VQA Batch')
            vqa_prompt = '###Human: <Img><ImageHere></Img> '
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
        elif self.prompt_list:
            prompt = random.choice(self.prompt_list)
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)

        self.llama_tokenizer.padding_side = "right"

        text = [t + self.end_sym for t in samples["text_input"]]

        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
        ).to(image.device)

        # Padding tokens are masked out of the loss with -100.
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )

        # The BOS token and the image-embedding positions are also ignored in the loss.
        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
                       dtype=torch.long).to(image.device).fill_(-100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.llama_model.model.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )
        loss = outputs.loss

        return {"loss": loss}

    @classmethod
    def from_config(cls, cfg):
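        """Build a MiniGPT4 instance from a config node and optionally load a
        fine-tuned checkpoint given under `cfg.ckpt`."""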
        vit_model = cfg.get("vit_model", "eva_clip_g")
        q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
        img_size = cfg.get("image_size")
        num_query_token = cfg.get("num_query_token")
        llama_model = cfg.get("llama_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        freeze_qformer = cfg.get("freeze_qformer", True)
        llama_cache_dir = cfg.get("llama_cache_dir", "")

        prompt_path = cfg.get("prompt_path", "")
        prompt_template = cfg.get("prompt_template", "")
        max_txt_len = cfg.get("max_txt_len", 32)
        end_sym = cfg.get("end_sym", '\n')

        model = cls(
            vit_model=vit_model,
            q_former_model=q_former_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            freeze_qformer=freeze_qformer,
            llama_cache_dir=llama_cache_dir,
            num_query_token=num_query_token,
            llama_model=llama_model,
            prompt_path=prompt_path,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            end_sym=end_sym
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
        if ckpt_path:
            print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            msg = model.load_state_dict(ckpt['model'], strict=False)

        return model
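
# A minimal usage sketch (assumes a parsed model config `cfg` with the fields
# read in from_config and a preprocessed batch dict; the variable names below
# are illustrative, not part of this module):
#
#   model = MiniGPT4.from_config(cfg)
#   samples = {"image": images, "text_input": captions}
#   out = model(samples)
#   out["loss"].backward()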