import types
from typing import Optional, Tuple

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.modeling_outputs import BaseModelOutputWithPast

from opencompass.registry import MM_MODELS

from .generation_utils import decode_tokens, make_context


@MM_MODELS.register_module('qwen-vl-base')
class QwenVLBase(nn.Module):
    """Inference code of Qwen-VL.

    We load the Qwen-VL model via HuggingFace.

    Args:
        pretrained_path (str): Path to a Qwen-VL checkpoint or repo id.
        prompt_constructor (dict): The config of the prompt constructor.
        post_processor (dict): The config of the post processor.
        is_caption_task (bool): Whether the task is a caption task.
            Defaults to False.
        commit_id (str): Use the given revision of Qwen-VL.
            Warning: the latest revision may conflict with this code;
            the pinned default revision is recommended.
    """

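    # A minimal construction sketch: the registry name 'qwen-vl-base' and the
    # constructor arguments are the ones defined below, while the config
    # ``type`` strings are placeholders for whichever prompt constructor and
    # post processor are registered in MM_MODELS for your dataset.
    #
    #   model = QwenVLBase(
    #       pretrained_path='Qwen/Qwen-VL',
    #       prompt_constructor=dict(type='SomePromptConstructor'),
    #       post_processor=dict(type='SomePostProcessor'))
    #   data_sample = model.forward(batch)  # batch from the MM dataloader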
    def __init__(
        self,
        pretrained_path: str,
        prompt_constructor: Optional[dict] = None,
        post_processor: Optional[dict] = None,
        is_caption_task: bool = False,
        commit_id: str = '548275c8b99de56dec203c0e793be18e030f2f4c'
    ) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
                                                       trust_remote_code=True,
                                                       revision=commit_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            pretrained_path,
            device_map=get_device(),
            trust_remote_code=True,
            revision=commit_id)
        self.model.generation_config = GenerationConfig.from_pretrained(
            pretrained_path, trust_remote_code=True, revision=commit_id)
        if prompt_constructor is not None:
            self.prompt_constructor = mmengine.registry.build_from_cfg(
                prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        else:
            self.post_processor = None
        self.is_caption_task = is_caption_task
        # Replace the transformer's forward with `forward_hack` (defined at
        # the bottom of this file), which can consume precomputed
        # `inputs_embeds` carrying injected visual features.
        self.model.transformer.forward = types.MethodType(
            forward_hack, self.model.transformer)

    def _build_embeds(self, images, input_ids):
        # Encode the images with Qwen-VL's visual encoder.
        images = self.model.transformer.visual(images)

        # Locate the image placeholder spans delimited by the image
        # start/end tokens.
        bos_pos = torch.where(input_ids == self.model.transformer.config.
                              visual['image_start_id'])
        eos_pos = torch.where(
            input_ids ==
            self.model.transformer.config.visual['image_start_id'] + 1)
        assert (bos_pos[0] == eos_pos[0]).all()
        img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)

        # Insert the visual features into the corresponding embedding
        # positions.
        inputs_embeds = self.model.transformer.wte(input_ids)
        for idx, (i, a, b) in enumerate(img_pos):
            inputs_embeds[i][a + 1:b] = images[idx]
        return inputs_embeds

    def generate(self, batch):
        images = batch.pop('inputs')
        images = torch.stack(images, dim=0)
        format_input = self.prompt_constructor(batch)
        query = self.tokenizer.from_list_format(format_input)

        inputs = self.tokenizer(query, return_tensors='pt')
        inputs = inputs.to(get_device())
        input_ids, token_type_ids, attention_mask = inputs[
            'input_ids'], inputs['token_type_ids'], inputs['attention_mask']
        inputs_embeds = self._build_embeds(images, input_ids)
        pred = self.model.generate(input_ids=input_ids,
                                   inputs_embeds=inputs_embeds,
                                   attention_mask=attention_mask,
                                   token_type_ids=token_type_ids)
        response = self.post_processor(pred.cpu()[0])

        data_sample = batch['data_samples'][0]
        if self.is_caption_task:
            data_sample.pred_caption = response
        else:
            data_sample.pred_answer = response
        return data_sample

    def forward(self, batch):
        return self.generate(batch)


@MM_MODELS.register_module('qwen-vl-chat')
class QwenVLChat(QwenVLBase):
    """Inference code of Qwen-VL-Chat.

    We load the Qwen-VL-Chat model via HuggingFace.

    Args:
        pretrained_path (str): Path to a Qwen-VL-Chat checkpoint or repo id.
        prompt_constructor (dict): The config of the prompt constructor.
        post_processor (dict): The config of the post processor.
        is_caption_task (bool): Whether the task is a caption task.
            Defaults to False.
    """

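    # A minimal usage sketch, mirroring QwenVLBase above: 'Qwen/Qwen-VL-Chat'
    # is the chat checkpoint on HuggingFace, while the prompt constructor /
    # post processor ``type`` strings are placeholders for components
    # registered in MM_MODELS.
    #
    #   model = QwenVLChat(
    #       pretrained_path='Qwen/Qwen-VL-Chat',
    #       prompt_constructor=dict(type='SomePromptConstructor'))
    #   data_sample = model.forward(batch)  # returns the filled data sample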
    def __init__(self,
                 pretrained_path: str,
                 prompt_constructor: Optional[dict] = None,
                 post_processor: Optional[dict] = None,
                 is_caption_task: bool = False) -> None:
        super().__init__(pretrained_path, prompt_constructor, post_processor,
                         is_caption_task)

    def generate(self, batch):
        images = batch.pop('inputs')
        images = torch.stack(images, dim=0)
        format_input = self.prompt_constructor(batch)
        query = self.tokenizer.from_list_format(format_input)

        raw_text, context_tokens = make_context(
            self.tokenizer,
            query,
            system='You are a helpful assistant.',
            chat_format=self.model.generation_config.chat_format,
        )

        input_ids = torch.tensor([context_tokens]).to(get_device())

        inputs_embeds = self._build_embeds(images, input_ids)
        pred = self.model.generate(input_ids=input_ids,
                                   inputs_embeds=inputs_embeds)

        response = decode_tokens(
            pred[0],
            self.tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            chat_format=self.model.generation_config.chat_format,
            verbose=False,
            errors='replace')

        if self.post_processor:
            response = self.post_processor(response)

        data_sample = batch['data_samples'][0]
        if self.is_caption_task:
            data_sample.pred_caption = response
        else:
            data_sample.pred_answer = response
        return data_sample


def forward_hack(self,
                 input_ids: Optional[torch.LongTensor] = None,
                 past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
                 attention_mask: Optional[torch.FloatTensor] = None,
                 token_type_ids: Optional[torch.LongTensor] = None,
                 position_ids: Optional[torch.LongTensor] = None,
                 head_mask: Optional[torch.FloatTensor] = None,
                 inputs_embeds: Optional[torch.FloatTensor] = None,
                 encoder_hidden_states: Optional[torch.Tensor] = None,
                 encoder_attention_mask: Optional[torch.FloatTensor] = None,
                 use_cache: Optional[bool] = None,
                 output_attentions: Optional[bool] = None,
                 output_hidden_states: Optional[bool] = None,
                 return_dict: Optional[bool] = None):
    """Patched forward for the Qwen transformer that also accepts
    precomputed ``inputs_embeds`` during generation.

    It is bound to ``self.model.transformer`` with ``types.MethodType`` in
    ``QwenVLBase.__init__``.
    """
    # On the first (prefill) step, recover the image references encoded
    # between the image start/end tokens and run them through the visual
    # encoder.
    if past_key_values is None and input_ids is not None and torch.any(
            input_ids == self.config.visual['image_start_id']):
        bos_pos = torch.where(
            input_ids == self.config.visual['image_start_id'])
        eos_pos = torch.where(
            input_ids == self.config.visual['image_start_id'] + 1)
        assert (bos_pos[0] == eos_pos[0]).all()
        img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
        images = []
        for i, a, b in img_pos:
            image = input_ids[i][a + 1:b - 1].tolist()
            image = image[:image.index(self.config.visual['image_start_id'] +
                                       2)]
            images.append(bytes(image).decode('utf-8'))

        images = self.visual.encode(images)
        assert images.shape[0] == len(images)
    else:
        images = None

    output_attentions = (output_attentions if output_attentions is not None
                         else self.config.output_attentions)
    output_hidden_states = (output_hidden_states if output_hidden_states
                            is not None else self.config.output_hidden_states)
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = (return_dict if return_dict is not None else
                   self.config.use_return_dict)

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            'You cannot specify both input_ids and inputs_embeds at the same time'
        )
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        batch_size = inputs_embeds.shape[0]
    else:
        raise ValueError(
            'You have to specify either input_ids or inputs_embeds')

    device = input_ids.device if input_ids is not None else inputs_embeds.device

    if token_type_ids is not None:
        token_type_ids = token_type_ids.view(-1, input_shape[-1])
    if position_ids is not None:
        position_ids = position_ids.view(-1, input_shape[-1])

    if past_key_values is None:
        past_length = 0
        past_key_values = tuple([None] * len(self.h))
    else:
        past_length = past_key_values[0][0].size(-2)

    if position_ids is None:
        position_ids = torch.arange(
            past_length,
            input_shape[-1] + past_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

    encoder_attention_mask = None
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    if inputs_embeds is None:
        inputs_embeds = self.wte(input_ids)

    if batch_size <= 0:
        raise ValueError('batch_size has to be defined and > 0')
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, input_shape, inputs_embeds, past_length)

    hidden_states = inputs_embeds

    hidden_states = self.drop(hidden_states)
    # Overwrite the image placeholder positions with the encoded visual
    # features.
    if images is not None:
        for idx, (i, a, b) in enumerate(img_pos):
            hidden_states[i][a + 1:b] = images[idx]
    output_shape = input_shape + (hidden_states.size(-1), )

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None
    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):

                def custom_forward(*inputs):
                    return module(*inputs, use_cache, output_attentions)

                return custom_forward

            outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                None,
                attention_mask,
                head_mask[i],
                encoder_hidden_states,
                encoder_attention_mask,
            )
        else:
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

        hidden_states = outputs[0]
        if use_cache is True:
            presents = presents + (outputs[2 if output_attentions else 1], )

        if output_attentions:
            all_self_attentions = all_self_attentions + (outputs[1], )

    hidden_states = self.ln_f(hidden_states)
    hidden_states = hidden_states.view(output_shape)

    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states, )

    if not return_dict:
        return tuple(v for v in [hidden_states, presents, all_hidden_states]
                     if v is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
    )