File size: 5,456 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import importlib
import os
import sys
import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import StoppingCriteria
from opencompass.registry import MM_MODELS
# Token id handed to llava's ``tokenizer_image_token`` in ``LLaVA.generate``.
# NOTE(review): presumably the placeholder id that marks the image slot in the
# tokenized prompt and must match the constant used by the LLaVA repo itself
# -- TODO confirm against the cloned LLaVA checkout.
IMAGE_TOKEN_INDEX = -200
def load_package():
    """Load required packages from LLaVA.

    Makes the ``LLaVA`` directory located next to this file importable by
    adding it to ``sys.path``. The path is appended at most once, so calling
    this function repeatedly (e.g. when several models are built in the same
    process) does not keep growing ``sys.path`` with duplicate entries.
    """
    current_file_path = os.path.abspath(__file__)
    current_folder_path = os.path.dirname(current_file_path)
    llava_path = os.path.join(current_folder_path, 'LLaVA')
    # Guard against duplicates: the original unconditional append added one
    # more identical entry on every call.
    if llava_path not in sys.path:
        sys.path.append(llava_path)  # noqa
    return
class KeywordsStoppingCriteria(StoppingCriteria):
    """Keyword stopping criteria implemented for llava.

    Signals that generation should stop once any of the given keywords
    shows up in the decoded text produced after the prompt.

    Args:
        keywords: Strings whose appearance in the generated text ends
            generation.
        tokenizer: Object providing ``batch_decode``; used to turn the
            generated ids back into text.
        input_ids: Prompt token ids; ``input_ids.shape[1]`` is taken as the
            prompt length.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.input_ids = input_ids
        self.tokenizer = tokenizer
        self.keywords = keywords
        # Filled in lazily on the first ``__call__``.
        self.start_len = None

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor,
                 **kwargs) -> bool:
        """Return True when a keyword occurs in the newly generated text.

        The very first invocation only records the prompt length and never
        requests a stop. Only the first sequence of the batch is inspected.
        """
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
            return False

        # Decode just the part that follows the prompt.
        new_token_ids = output_ids[:, self.start_len:]
        decoded = self.tokenizer.batch_decode(new_token_ids,
                                              skip_special_tokens=True)
        generated_text = decoded[0]
        return any(keyword in generated_text for keyword in self.keywords)
@MM_MODELS.register_module('llava')
class LLaVA(nn.Module):
    """Inference code of LLaVA. Need to clone LLaVA official repo first. Please
    check out the README in config.

    Args:
        model_path (str): The path of llava checkpoint.
        prompt_constructor (dict): The config of prompt constructor. The
            passed dict is not modified; ``conv_mode`` and
            ``mm_use_im_start_end`` are added to a copy before building.
        post_processor (dict): The config of post processor.
        is_caption_task (bool): Whether the task is caption task.
            Defaults to False.
    """

    def __init__(
        self,
        model_path: str,
        prompt_constructor: dict,
        post_processor: dict,
        is_caption_task: bool = False,
    ) -> None:
        super().__init__()
        self.dtype = torch.float16
        self.is_caption_task = is_caption_task

        # load LLaVA modules
        load_package()
        mm_utils = importlib.import_module('llava.mm_utils')
        builder = importlib.import_module('llava.model.builder')

        # load pretrained LLaVA
        # Note: When encounters with device related errors,
        # try setting `low_cpu_mem_usage` in `load_pretrained_model` as False
        model_name = mm_utils.get_model_name_from_path(model_path)
        tokenizer, model, _, _ = builder.load_pretrained_model(
            model_path, None, model_name)
        vision_tower = model.get_vision_tower()
        vision_tower.to(device=get_device(), dtype=self.dtype)
        model.to(device=get_device(), dtype=self.dtype)

        # Pick the conversation template from substrings of the checkpoint
        # path ('v1' takes precedence over 'mpt').
        lowered_path = model_path.lower()
        if 'v1' in lowered_path:
            conv_mode = 'llava_v1'
        elif 'mpt' in lowered_path:
            conv_mode = 'mpt_multimodal'
        else:
            conv_mode = 'multimodal'

        mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end',
                                      False)

        # load prompt constructor and post processor
        # Work on a shallow copy so the caller's config dict is left intact
        # (the original updated the argument in place).
        prompt_constructor_cfg = dict(prompt_constructor)
        prompt_constructor_cfg.update({
            'conv_mode': conv_mode,
            'mm_use_im_start_end': mm_use_im_start_end
        })
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor_cfg, MM_MODELS)
        self.post_processor = mmengine.registry.build_from_cfg(
            post_processor, MM_MODELS)
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, batch):
        """Generate a prediction for the first sample of ``batch``.

        Args:
            batch (dict): Must provide ``'inputs'`` and ``'data_samples'``;
                only index 0 of each is used. It is also passed as a whole
                to the prompt constructor.

        Returns:
            The first data sample, with the post-processed text stored in
            ``pred_caption`` (caption task) or ``pred_answer`` (otherwise).
        """
        prompt, stop_str = self.prompt_constructor(batch)
        keywords = [stop_str]
        data_sample = batch['data_samples'][0]

        # Test for None *before* touching the tensor; the original called
        # ``unsqueeze`` first, which made its None branch unreachable and
        # would then have crashed on ``images.half()``.
        image = batch['inputs'][0]
        images = None
        if image is not None:
            # Add the batch dimension and match the dtype the model was
            # moved to in ``__init__`` (float16).
            images = image.unsqueeze(0).to(device=get_device(),
                                           dtype=self.dtype)
        # NOTE(review): whether the wrapped model accepts ``images=None``
        # depends on the LLaVA implementation -- TODO confirm.

        mm_utils = importlib.import_module('llava.mm_utils')
        input_ids = mm_utils.tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
            return_tensors='pt').unsqueeze(0).to(get_device())

        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer,
                                                     input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images,
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                stopping_criteria=[stopping_criteria],
            )

        # Sanity check: the leading part of the output is expected to echo
        # the prompt ids; report how many positions differ.
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids !=
                               output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(
                f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids'  # noqa
            )

        # Decode only the tokens that follow the prompt.
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
                                              skip_special_tokens=True)[0]
        output_text = self.post_processor(outputs, stop_str)

        if self.is_caption_task:
            data_sample.pred_caption = output_text
        else:
            data_sample.pred_answer = output_text
        return data_sample

    def forward(self, batch):
        """Alias of :meth:`generate`."""
        return self.generate(batch)
|