TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame
3.75 kB
from typing import Optional
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import AutoModel, AutoTokenizer
from opencompass.registry import MM_MODELS
@MM_MODELS.register_module('visualglm')
class VisualGLM(nn.Module):
    """Inference wrapper around VisualGLM.

    The tokenizer and model are loaded through the Huggingface ``Auto*``
    classes with ``trust_remote_code=True``; the model is cast to half
    precision right after loading.

    Args:
        pretrained_path (str): Path to visualGLM checkpoint or repo id.
        prompt_constructor (dict): The config of prompt constructor.
        post_processor (dict): The config of post processor.
        is_caption_task (bool): Whether the task is caption task.
            Defaults to False.
        gen_kwargs (dict): Customize generate function arguments.
            When omitted (or empty) a built-in sampling configuration
            is used instead. Defaults to None.
    """

    def __init__(self,
                 pretrained_path: str,
                 prompt_constructor: dict,
                 post_processor: dict,
                 is_caption_task: bool = False,
                 gen_kwargs: Optional[dict] = None) -> None:
        super().__init__()

        # Plain (non-module) settings.
        self.is_caption_task = is_caption_task
        # Any falsy value (None, empty dict) selects the default sampling
        # configuration below; a non-empty dict is stored as passed.
        self.gen_kwargs = gen_kwargs or {
            'max_length': 1024,
            'min_length': 100,
            'do_sample': True,
            'temperature': 0.8,
            'top_p': 0.4,
            'top_k': 100,
            'repetition_penalty': 1.2,
        }

        # Huggingface components; the checkpoint ships its own modelling
        # code, hence ``trust_remote_code``.
        self.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_path, trust_remote_code=True)
        loaded_model = AutoModel.from_pretrained(
            pretrained_path, trust_remote_code=True)
        self.model = loaded_model.half()

        # Helper objects are built from their configs via the MM_MODELS
        # registry.
        build = mmengine.registry.build_from_cfg
        self.prompt_constructor = build(prompt_constructor, MM_MODELS)
        self.post_processor = build(post_processor, MM_MODELS)

    def encode_by_tokenizer(self, prompt, image_position):
        """Tokenize ``prompt`` with image placeholders spliced in.

        The prompt is split at ``image_position``; both halves are encoded
        without special tokens and ``self.model.image_length`` copies of
        the tokenizer's unk token id are inserted between them. Special
        tokens are then added around the combined sequence.

        Args:
            prompt: Prompt to encode; it is sliced at ``image_position``.
            image_position: Index in ``prompt`` at which the image
                placeholder tokens are inserted.

        Returns:
            tuple: ``(input_ids, pre_image_len)`` where ``input_ids`` is a
            long tensor with a leading dimension of size 1, placed on the
            device returned by ``get_device()``, and ``pre_image_len`` is
            the number of token ids encoded from the text before the
            image (counted before special tokens are added).
        """
        prefix_ids = self.tokenizer.encode(prompt[:image_position],
                                           add_special_tokens=False)
        placeholder_ids = ([self.tokenizer.unk_token_id] *
                           self.model.image_length)
        suffix_ids = self.tokenizer.encode(prompt[image_position:],
                                           add_special_tokens=False)

        token_ids = self.tokenizer.build_inputs_with_special_tokens(
            prefix_ids + placeholder_ids + suffix_ids)

        input_ids = torch.tensor(token_ids, dtype=torch.long)
        input_ids = input_ids.to(get_device()).unsqueeze(0)
        return input_ids, len(prefix_ids)

    def generate(self, batch):
        """Run generation for ``batch`` and store the prediction.

        Args:
            batch: Input passed unchanged to ``self.prompt_constructor``,
                which must return ``(image, prompt, data_sample,
                image_position)``.

        Returns:
            The ``data_sample`` produced by the prompt constructor, with
            the post-processed answer stored in ``pred_caption`` (caption
            task) or ``pred_answer`` (otherwise).
        """
        image, prompt, data_sample, image_position = \
            self.prompt_constructor(batch)
        # Match the model's dtype first, then move to the active device.
        image = image.to(self.model.dtype).to(get_device())

        input_ids, pre_image_len = self.encode_by_tokenizer(
            prompt, image_position)

        generated = self.model.generate(input_ids=input_ids,
                                        pre_image_length=pre_image_len,
                                        images=image,
                                        **self.gen_kwargs)

        # Keep the first sequence only and drop its first
        # ``input_ids.shape[1]`` entries, i.e. the prompt-length prefix.
        prompt_len = input_ids.shape[1]
        first_sequence = generated.tolist()[0]
        answer = self.post_processor(first_sequence[prompt_len:],
                                     self.tokenizer)

        result_field = ('pred_caption'
                        if self.is_caption_task else 'pred_answer')
        setattr(data_sample, result_field, answer)
        return data_sample

    def forward(self, batch):
        """Alias of :meth:`generate` so the module can be called directly."""
        return self.generate(batch)