"""Requires Transformer 4.28 and above, implementation may change according the Llama implementation.""" import logging import mmengine import torch import torch.nn as nn from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train from mmengine.device import get_device from transformers import LlamaForCausalLM, LlamaTokenizer from opencompass.registry import MM_MODELS @MM_MODELS.register_module('blip2-vicuna-instruct') class InstructBlipInferencer(Blip2Base): def __init__( self, prompt_constructor: dict, post_processor: dict, vit_model: str = 'eva_clip_g', img_size: int = 224, drop_path_rate: float = 0, use_grad_checkpoint: bool = False, vit_precision: str = 'fp16', freeze_vit: bool = True, num_query_token: int = 32, llm_model: str = '', sys_prompt: str = '', prompt: str = '', max_txt_len: int = 128, max_output_txt_len: int = 256, qformer_text_input: bool = True, low_resource: bool = False, mode: str = 'generation', is_caption_task=False, ): super().__init__() self.mode = mode self.prompt_constructor = mmengine.registry.build_from_cfg( prompt_constructor, MM_MODELS) self.post_processor = mmengine.registry.build_from_cfg( post_processor, MM_MODELS) self.tokenizer = self.init_tokenizer(truncation_side='left') self.visual_encoder, self.ln_vision = self.init_vision_encoder( vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision) if freeze_vit: for name, param in self.visual_encoder.named_parameters(): param.requires_grad = False self.visual_encoder = self.visual_encoder.eval() self.visual_encoder.train = disabled_train logging.info('freeze vision encoder') self.Qformer, self.query_tokens = self.init_Qformer( num_query_token, self.visual_encoder.num_features) if not qformer_text_input: self.Qformer.bert.embeddings.word_embeddings = None self.Qformer.bert.embeddings.position_embeddings = None for layer in self.Qformer.bert.encoder.layer: layer.output = None layer.intermediate = None else: self.Qformer.resize_token_embeddings(len(self.tokenizer)) self.Qformer.cls = None self.llm_tokenizer = LlamaTokenizer.from_pretrained( llm_model, use_fast=False, truncation_side='left') if low_resource: self.llm_model = LlamaForCausalLM.from_pretrained( llm_model, torch_dtype=torch.float16, load_in_8bit=True, device_map={'': 0}) else: self.llm_model = LlamaForCausalLM.from_pretrained( llm_model, torch_dtype=torch.float16) self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) self.llm_tokenizer.add_special_tokens({'bos_token': ''}) self.llm_tokenizer.add_special_tokens({'eos_token': ''}) self.llm_tokenizer.add_special_tokens({'unk_token': ''}) self.llm_model.resize_token_embeddings(len(self.llm_tokenizer)) for name, param in self.llm_model.named_parameters(): param.requires_grad = False self.llm_proj = nn.Linear(self.Qformer.config.hidden_size, self.llm_model.config.hidden_size) self.max_txt_len = max_txt_len self.max_output_txt_len = max_output_txt_len self.sys_prompt = sys_prompt self.prompt = prompt self.is_caption_task = is_caption_task self._lemmatizer = None self.qformer_text_input = qformer_text_input def forward(self, batch): if self.mode == 'generation': return self.generate(batch) else: raise RuntimeError(f'Invalid mode "{self.mode}".') def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts): input_part_targets_len = [] llm_tokens = {'input_ids': [], 'attention_mask': []} for i in range(input_ids.size(0)): this_input_ones = input_atts[i].sum() input_part_targets_len.append(this_input_ones) llm_tokens['input_ids'].append( torch.cat([ 
input_ids[i][:this_input_ones], output_ids[i][1:], input_ids[i][this_input_ones:] ])) llm_tokens['attention_mask'].append( torch.cat([ input_atts[i][:this_input_ones], output_atts[i][1:], input_atts[i][this_input_ones:] ])) llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids']) llm_tokens['attention_mask'] = torch.stack( llm_tokens['attention_mask']) return llm_tokens, input_part_targets_len def pack_inputs(self, batch): images = [image.unsqueeze(0) for image in batch['inputs']] data_samples = [data_sample for data_sample in batch['data_samples']] images = torch.cat(images, dim=0).to(get_device()) inputs = {'image': images, 'data_samples': data_samples} return inputs @torch.no_grad() def generate( self, batch, use_nucleus_sampling=False, num_beams=5, max_length=256, min_length=1, top_p=0.9, repetition_penalty=1.5, length_penalty=1, num_captions=1, temperature=1, ): inputs = self.pack_inputs(batch) inputs = self.prompt_constructor(inputs) image = inputs['image'] prompt = inputs['prompt'] data_samples = inputs['data_samples'] self.llm_tokenizer.padding_side = 'left' bs = image.size(0) if isinstance(prompt, str): prompt = [prompt] * bs else: assert len( prompt ) == bs, 'The number of prompts must be equal to the batch size.' query_tokens = self.query_tokens.expand(bs, -1, -1) if self.qformer_text_input: text_Qformer = self.tokenizer( prompt, padding='longest', truncation=True, max_length=self.max_txt_len, return_tensors='pt', ).to(image.device) query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1) with self.maybe_autocast(): image_embeds = self.ln_vision(self.visual_encoder(image)) image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) if self.qformer_text_input: query_output = self.Qformer.bert( text_Qformer.input_ids, attention_mask=Qformer_atts, query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts, return_dict=True, ) else: query_output = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts, return_dict=True, ) inputs_llm = self.llm_proj( query_output.last_hidden_state[:, :query_tokens.size(1), :]) atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device) prompt = ['###Human: ' + p + '###Assistant:' for p in prompt] prompt = [self.sys_prompt + p for p in prompt] llm_tokens = self.llm_tokenizer(prompt, padding='longest', return_tensors='pt').to(image.device) with self.maybe_autocast(): inputs_embeds = self.llm_model.get_input_embeddings()( llm_tokens.input_ids) inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1) attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1) outputs = self.llm_model.generate( inputs_embeds=inputs_embeds, attention_mask=attention_mask, do_sample=use_nucleus_sampling, top_p=top_p, temperature=temperature, num_beams=num_beams, max_length=self.max_output_txt_len, min_length=min_length, repetition_penalty=repetition_penalty, length_penalty=length_penalty, num_return_sequences=num_captions, ) for i, data_sample in enumerate(data_samples): output_token = outputs[i] output_text = self.post_processor(output_token, self.llm_tokenizer) if self.is_caption_task: data_sample.pred_caption = output_text else: data_sample.pred_answer = output_text data_samples[i] = data_sample return data_samples
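

# ----------------------------------------------------------------------------
# Usage sketch (comment only, not executed): a minimal, hypothetical way to
# drive the inferencer. The registry type names and the checkpoint path below
# are placeholders rather than identifiers defined in this file; the batch
# layout follows `pack_inputs`, which expects a list of CHW image tensors
# under 'inputs' and the matching data samples under 'data_samples'.
#
#   model = InstructBlipInferencer(
#       prompt_constructor=dict(type='SomePromptConstructor'),  # placeholder
#       post_processor=dict(type='SomePostProcessor'),          # placeholder
#       llm_model='/path/to/vicuna-7b',                         # placeholder
#   )
#   batch = dict(inputs=[image_tensor], data_samples=[data_sample])
#   data_samples = model(batch)  # mode='generation' dispatches to generate()
#   print(data_samples[0].pred_answer)  # or .pred_caption for caption tasks
# ----------------------------------------------------------------------------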