from typing import Optional

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import AutoModel, AutoTokenizer

from opencompass.registry import MM_MODELS


@MM_MODELS.register_module('visualglm')
class VisualGLM(nn.Module):
    """Inference code of VisualGLM.

    We load the VisualGLM model via HuggingFace.

    Args:
        pretrained_path (str): Path to the VisualGLM checkpoint or repo id.
        prompt_constructor (dict): The config of the prompt constructor.
        post_processor (dict): The config of the post processor.
        is_caption_task (bool): Whether the task is a captioning task.
            Defaults to False.
        gen_kwargs (dict): Customized arguments passed to the ``generate``
            function. Defaults to None.
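
    Example:
        A minimal construction sketch; the ``type`` names of the prompt
        constructor and post processor below are illustrative placeholders,
        not verified registry keys:

        >>> model = VisualGLM(
        ...     pretrained_path='THUDM/visualglm-6b',
        ...     prompt_constructor=dict(type='VisualGLMPromptConstructor'),
        ...     post_processor=dict(type='VisualGLMPostProcessor'),
        ...     gen_kwargs=dict(max_length=256, do_sample=False))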
""" |
|
|
|
    def __init__(self,
                 pretrained_path: str,
                 prompt_constructor: dict,
                 post_processor: dict,
                 is_caption_task: bool = False,
                 gen_kwargs: Optional[dict] = None) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path,
                                                       trust_remote_code=True)
        # Cast the model to fp16 to halve the memory footprint.
        self.model = AutoModel.from_pretrained(pretrained_path,
                                               trust_remote_code=True).half()
        # Build the prompt constructor and post processor from their configs.
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        self.post_processor = mmengine.registry.build_from_cfg(
            post_processor, MM_MODELS)

        if gen_kwargs:
            self.gen_kwargs = gen_kwargs
        else:
            # Sampling defaults used when no gen_kwargs are supplied.
            self.gen_kwargs = dict(max_length=1024,
                                   min_length=100,
                                   do_sample=True,
                                   temperature=0.8,
                                   top_p=0.4,
                                   top_k=100,
                                   repetition_penalty=1.2)

        self.is_caption_task = is_caption_task

    def encode_by_tokenizer(self, prompt, image_position):
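        # Split the prompt at the image position and reserve
        # `self.model.image_length` unk-token placeholders in between;
        # the model fills these placeholder positions with the projected
        # image features during generation.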
        input0 = self.tokenizer.encode(prompt[:image_position],
                                       add_special_tokens=False)
        input1 = [self.tokenizer.unk_token_id] * self.model.image_length
        input2 = self.tokenizer.encode(prompt[image_position:],
                                       add_special_tokens=False)
        # Concatenate the three segments and add the special tokens
        # (e.g. BOS/EOS) expected by the model.
        input_all = sum([input0, input1, input2], [])
        input_all = self.tokenizer.build_inputs_with_special_tokens(input_all)
        input_all = torch.tensor(input_all, dtype=torch.long).to(get_device())
        input_all = input_all.unsqueeze(0)

        # Number of text tokens preceding the image placeholders.
        pre_image_len = len(input0)

        return input_all, pre_image_len

    def generate(self, batch):
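        # Inference flow: build the multimodal prompt, tokenize it with
        # image placeholders, run HF `generate`, then strip the prompt
        # tokens and post-process only the newly generated ones.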
        image, prompt, data_sample, image_position = self.prompt_constructor(
            batch)
        image = image.to(self.model.dtype).to(get_device())

        input_all, pre_image_len = self.encode_by_tokenizer(
            prompt, image_position)

        inputs = {
            'input_ids': input_all,
            'pre_image_length': pre_image_len,
            'images': image
        }

        outputs = self.model.generate(**inputs, **self.gen_kwargs)

        # Drop the prompt tokens; keep only the generated continuation.
        outputs = outputs.tolist()[0][input_all.shape[1]:]
        answer = self.post_processor(outputs, self.tokenizer)

        if self.is_caption_task:
            data_sample.pred_caption = answer
        else:
            data_sample.pred_answer = answer

        return data_sample

    def forward(self, batch):
        return self.generate(batch)
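
# A minimal build-and-run sketch (illustrative only): the model is looked up
# in MM_MODELS under the name registered above, the constructor/processor
# ``type`` keys are placeholders, and ``batch`` must follow the OpenCompass
# multimodal dataloader format.
#
#   model = MM_MODELS.build(
#       dict(type='visualglm',
#            pretrained_path='THUDM/visualglm-6b',
#            prompt_constructor=dict(type='VisualGLMPromptConstructor'),
#            post_processor=dict(type='VisualGLMPostProcessor')))
#   data_sample = model(batch)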