import importlib

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device

from opencompass.registry import MM_MODELS


@MM_MODELS.register_module('otter-9b')
class Otter(nn.Module):
    """Inference code of OTTER.

    Model details:
        OTTER: a multi-modal model based on OpenFlamingo
        (an open-sourced version of DeepMind's Flamingo).
        https://github.com/Luodian/Otter

    Args:
        model_path (str): The path of the OTTER model
            in HuggingFace model hub format.
        load_bit (str): The precision to load the model weights in,
            either 'fp32' or 'bf16'.
        prompt_constructor (dict): Config used to build the prompt
            constructor from the ``MM_MODELS`` registry.
        post_processor (dict, optional): Config used to build the post
            processor from the ``MM_MODELS`` registry.
        mode (str): The mode of inference, 'generation' or 'loss'.
            Defaults to 'generation'.
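
    Example:
        A minimal sketch; the hub path and the ``type`` names below are
        illustrative assumptions, not fixed by this class::

            model = Otter(
                model_path='luodian/OTTER-9B-LA-InContext',
                load_bit='bf16',
                prompt_constructor=dict(
                    type='OTTERMMBenchPromptConstructor'),
                post_processor=dict(type='OTTERMMBenchPostProcessor'),
            )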
    """

    def __init__(self,
                 model_path,
                 load_bit,
                 prompt_constructor,
                 post_processor,
                 mode='generation') -> None:
        super().__init__()
        torch_dtype = torch.bfloat16 if load_bit == 'bf16' else torch.float32
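        # Lazy import: otter_ai is an optional dependency and must be
        # installed (or on PYTHONPATH) before this model can be built.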
        otter_ai = importlib.import_module('otter_ai')
        self.model = otter_ai.OtterForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch_dtype, device_map=get_device())
        self.tokenizer = self.model.text_tokenizer
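        # Left-pad prompts so that, in a batch, generated tokens always
        # start right after the prompt of this decoder-only model.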
        self.tokenizer.padding_side = 'left'
        self.model_dtype = next(self.model.parameters()).dtype
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        self.mode = mode

    def forward(self, batch):
        if self.mode == 'generation':
            return self.generate(batch)
        elif self.mode == 'loss':
            return self.loss(batch)
        else:
            raise RuntimeError(f'Invalid mode "{self.mode}".')

    def generate(self, batch):
        inputs = self.prompt_constructor(batch)
        image = inputs['image']
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']
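        # OpenFlamingo-style models take vision_x of shape
        # (B, T_img, F, C, H, W); the two unsqueezes add the batch (B=1)
        # and frame (F=1) dimensions around the per-sample image stack.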
        vision_x = image.unsqueeze(1).unsqueeze(0).to(dtype=self.model_dtype)
        lang_x = self.model.text_tokenizer([prompt], return_tensors='pt')
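        # Forbid the dialogue role tags so that beam search cannot keep
        # continuing the conversation template on its own.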
        bad_words_id = self.model.text_tokenizer(['User:', 'GPT:']).input_ids
        generated_text = self.model.generate(
            vision_x=vision_x.to(self.model.device),
            lang_x=lang_x['input_ids'].to(self.model.device),
            attention_mask=lang_x['attention_mask'].to(self.model.device),
            do_sample=False,
            max_new_tokens=512,
            num_beams=3,
            bad_words_ids=bad_words_id,
            no_repeat_ngram_size=3,
        )
        for i, data_sample in enumerate(data_samples):
            output_text = self.post_processor(generated_text[i],
                                              self.model.text_tokenizer)
            data_sample.pred_answer = output_text
            data_samples[i] = data_sample

        return data_samples