import os
import sys

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device

from opencompass.registry import MM_MODELS


def load_package():
    """Load required packages from the bundled mPLUG-Owl repository."""
    current_file_path = os.path.abspath(__file__)
    current_folder_path = os.path.dirname(current_file_path)

    sys.path.append(os.path.join(current_folder_path, 'mPLUG-Owl'))  # noqa
    from mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
    from mplug_owl.processing_mplug_owl import (MplugOwlImageProcessor,
                                                MplugOwlProcessor)
    from mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
    sys.path.pop(-1)

    return MplugOwlForConditionalGeneration, MplugOwlImageProcessor, MplugOwlProcessor, MplugOwlTokenizer  # noqa


MplugOwlForConditionalGeneration, MplugOwlImageProcessor, MplugOwlProcessor, MplugOwlTokenizer = load_package(  # noqa
)  # noqa


@MM_MODELS.register_module('mplug_owl_7b')
class MplugOwl(nn.Module):
    """Inference wrapper around mPLUG-Owl (LLaMA-7B) for OpenCompass.

    Args:
        prompt_constructor (dict): Config for building the prompt constructor.
        post_processor (dict): Config for building the output post-processor.
        model_path (str): HuggingFace model id or local checkpoint path.
        mode (str): Inference mode; only 'generation' is supported.
    """

    def __init__(self,
                 prompt_constructor: dict,
                 post_processor: dict,
                 model_path='MAGAer13/mplug-owl-llama-7b',
                 mode: str = 'generation'):
        super().__init__()
        pretrained_ckpt = model_path
        # Load the weights in bfloat16 to halve memory usage.
        self.model = MplugOwlForConditionalGeneration.from_pretrained(
            pretrained_ckpt,
            torch_dtype=torch.bfloat16,
        ).to(get_device())
        self.image_processor = MplugOwlImageProcessor.from_pretrained(
            pretrained_ckpt)
        self.tokenizer = MplugOwlTokenizer.from_pretrained(pretrained_ckpt)
        self.processor = MplugOwlProcessor(self.image_processor,
                                           self.tokenizer)
        self.generate_kwargs = {
            'do_sample': False,
            'top_k': 5,
            'max_length': 20,
            'num_beams': 3,
        }

        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        self.mode = mode

    def forward(self, batch):
        if self.mode == 'generation':
            return self.generate(batch)

    def generate(self, batch):
        # Stack the per-sample image tensors into one batch on the target
        # device, then let the prompt constructor build the text prompt.
        images = [image.unsqueeze(0) for image in batch['inputs']]
        data_samples = list(batch['data_samples'])
        images = torch.cat(images, dim=0).to(get_device())
        inputs = {'image': images, 'data_samples': data_samples}
        inputs = self.prompt_constructor(inputs)
        image = inputs['image']
        prompt = inputs['prompt'][0]
        data_samples = inputs['data_samples']
        data_sample = data_samples[0]

        # The <image> placeholder marks where the visual tokens are spliced
        # into the conversation template.
        owl_template = """The following is a conversation between a curious human and AI assistant.
The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <image>
Human: {text_input}
AI: """
        prompt = owl_template.format(text_input=prompt)
        inputs = self.processor(text=[prompt], return_tensors='pt')
        inputs['pixel_values'] = image
        # Cast floating-point inputs to bfloat16 to match the model weights.
        inputs = {
            k: v.bfloat16() if v.dtype == torch.float else v
            for k, v in inputs.items()
        }
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            res = self.model.generate(**inputs, **self.generate_kwargs)
        output_text = self.tokenizer.decode(res.tolist()[0],
                                            skip_special_tokens=True)
        output_text = self.post_processor(output_text)
        data_sample.pred_answer = output_text
        return data_sample
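

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): the wrapper above is
# normally built through the MM_MODELS registry from a config dict. The
# prompt_constructor and post_processor type names below are assumptions for
# illustration; substitute the component names actually registered in your
# OpenCompass installation.
#
#     model_cfg = dict(
#         type='mplug_owl_7b',
#         model_path='MAGAer13/mplug-owl-llama-7b',
#         prompt_constructor=dict(type='MplugOwlMMBenchPromptConstructor'),
#         post_processor=dict(type='MplugOwlMMBenchPostProcessor'),
#         mode='generation',
#     )
#     model = MM_MODELS.build(model_cfg)
#
# Each forward pass then takes a batch dict with 'inputs' (a list of image
# tensors) and 'data_samples', and returns the first data sample with its
# `pred_answer` field filled in by the post-processed generation.
# ---------------------------------------------------------------------------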