import importlib

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device

from opencompass.registry import MM_MODELS


@MM_MODELS.register_module('otter-9b')
class Otter(nn.Module):
    """Inference code of OTTER.

    Model details:
        OTTER: a multi-modal model based on OpenFlamingo
        (open-sourced version of DeepMind's Flamingo)
        https://github.com/Luodian/Otter

    Args:
        model_path (str): The path of the OTTER model
            in HuggingFace model hub format.
        load_bit (str): The precision to load the OTTER model in,
            either 'fp32' or 'bf16'.
        prompt_constructor (dict): Config used to build the prompt
            constructor from the MM_MODELS registry.
        post_processor (dict, optional): Config used to build the
            post processor from the MM_MODELS registry.
        mode (str): The mode of inference, 'generation' or 'loss'.
            Defaults to 'generation'.
    """
    def __init__(self,
                 model_path,
                 load_bit,
                 prompt_constructor,
                 post_processor,
                 mode='generation') -> None:
        super().__init__()
        # 'bf16' loads the weights in bfloat16; anything else falls back
        # to full fp32 precision.
        torch_dtype = torch.bfloat16 if load_bit == 'bf16' else torch.float32
        otter_ai = importlib.import_module('otter_ai')
        self.model = otter_ai.OtterForConditionalGeneration.from_pretrained(
            model_path, torch_dtype=torch_dtype, device_map=get_device())
        self.tokenizer = self.model.text_tokenizer
        # Left padding so that generation starts right after the prompt
        # for every sample in a batch.
        self.tokenizer.padding_side = 'left'
        self.model_dtype = next(self.model.parameters()).dtype
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        self.mode = mode

    def forward(self, batch):
        if self.mode == 'generation':
            return self.generate(batch)
        elif self.mode == 'loss':
            return self.loss(batch)
        else:
            raise RuntimeError(f'Invalid mode "{self.mode}".')

    def generate(self, batch):
        inputs = self.prompt_constructor(batch)
        image = inputs['image']
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']
        # Add the extra frame/batch dims expected by the Flamingo-style
        # vision encoder and cast the pixels to the model's dtype.
        vision_x = image.unsqueeze(1).unsqueeze(0).to(dtype=self.model_dtype)
        lang_x = self.model.text_tokenizer([prompt], return_tensors='pt')
        # Forbid the role tags so the model does not start a new dialogue
        # turn inside its answer.
        bad_words_id = self.model.text_tokenizer(['User:', 'GPT:']).input_ids
        generated_text = self.model.generate(
            vision_x=vision_x.to(self.model.device),
            lang_x=lang_x['input_ids'].to(self.model.device),
            attention_mask=lang_x['attention_mask'].to(self.model.device),
            do_sample=False,
            max_new_tokens=512,
            num_beams=3,
            bad_words_ids=bad_words_id,
            no_repeat_ngram_size=3,
        )
        # Decode and post-process each generated sequence, then write the
        # answer back onto its data sample.
        for i, data_sample in enumerate(data_samples):
            output_text = self.post_processor(generated_text[i],
                                              self.model.text_tokenizer)
            data_sample.pred_answer = output_text
            data_samples[i] = data_sample
        return data_samples
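
# Illustrative usage (a minimal sketch, commented out so it has no effect at
# import time): how this wrapper is typically built and called through the
# MM_MODELS registry. The path and the prompt-constructor / post-processor
# types below are placeholders, not the project's actual configs.
#
#     from opencompass.registry import MM_MODELS
#
#     otter_cfg = dict(
#         type='otter-9b',
#         model_path='/path/to/otter-9b-checkpoint',  # placeholder path
#         load_bit='bf16',
#         prompt_constructor=dict(type='SomePromptConstructor'),  # placeholder
#         post_processor=dict(type='SomePostProcessor'),          # placeholder
#         mode='generation',
#     )
#     model = MM_MODELS.build(otter_cfg)
#     data_samples = model(batch)  # `batch` comes from the evaluation dataloader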