import importlib
import os
import sys

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import StoppingCriteria

from opencompass.registry import MM_MODELS

# Placeholder index that ``tokenizer_image_token`` inserts into the prompt;
# LLaVA later replaces it with the encoded image features.
IMAGE_TOKEN_INDEX = -200


def load_package():
    """Load required packages from LLaVA."""
    current_file_path = os.path.abspath(__file__)
    current_folder_path = os.path.dirname(current_file_path)

    sys.path.append(os.path.join(current_folder_path, 'LLaVA'))  # noqa
    return


class KeywordsStoppingCriteria(StoppingCriteria):
    """Keyword stopping criteria implemented for llava."""

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor,
                 **kwargs) -> bool:
        if self.start_len is None:
            # First call: record the prompt length so that only newly
            # generated tokens are checked afterwards.
            self.start_len = self.input_ids.shape[1]
        else:
            # Decode the tokens generated so far and stop as soon as any
            # keyword appears in them.
            outputs = self.tokenizer.batch_decode(
                output_ids[:, self.start_len:], skip_special_tokens=True)[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False


@MM_MODELS.register_module('llava')
class LLaVA(nn.Module):
    """Inference code of LLaVA. Need to clone LLaVA official repo first. Please
    check out the README in config.

    Args:
        model_path (str): The path of llava checkpoint.
        prompt_constructor (dict): The config of prompt constructor.
        post_processor (dict): The config of post processor.
        is_caption_task (bool): Whether the task is caption task.
            Defaults to False.
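
    Example (the ``type`` names below are assumptions; use the prompt
    constructor and post processor registered in your own config)::

        model = LLaVA(
            model_path='/path/to/llava-v1.5-7b',
            prompt_constructor=dict(type='LLaVAMMBenchPromptConstructor'),
            post_processor=dict(type='LLaVABasePostProcessor'),
        )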
    """

    def __init__(
        self,
        model_path: str,
        prompt_constructor: dict,
        post_processor: dict,
        is_caption_task: bool = False,
    ) -> None:
        super().__init__()
        self.dtype = torch.float16
        self.is_caption_task = is_caption_task

        # load LLaVA modules
        load_package()
        mm_utils = importlib.import_module('llava.mm_utils')
        builder = importlib.import_module('llava.model.builder')

        # load pretrained LLaVA
        # Note: if you run into device-related errors, try setting
        # `low_cpu_mem_usage` to False in `load_pretrained_model`.
        model_name = mm_utils.get_model_name_from_path(model_path)
        tokenizer, model, _, _ = builder.load_pretrained_model(
            model_path, None, model_name)
        vision_tower = model.get_vision_tower()
        vision_tower.to(device=get_device(), dtype=self.dtype)
        model.to(device=get_device(), dtype=self.dtype)

        # load prompt constructor and post processor
        if 'v1' in model_path.lower():
            conv_mode = 'llava_v1'
        elif 'mpt' in model_path.lower():
            conv_mode = 'mpt_multimodal'
        else:
            conv_mode = 'multimodal'
        mm_use_im_start_end = getattr(model.config, 'mm_use_im_start_end',
                                      False)
        prompt_constructor.update({
            'conv_mode': conv_mode,
            'mm_use_im_start_end': mm_use_im_start_end
        })
        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        self.post_processor = mmengine.registry.build_from_cfg(
            post_processor, MM_MODELS)
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, batch):
        """Generate a prediction for a single-sample batch and attach it to
        the returned data sample."""
        prompt, stop_str = self.prompt_constructor(batch)
        keywords = [stop_str]
        data_sample = batch['data_samples'][0]

        # Move the single-image batch to the current device.
        images = batch['inputs'][0].unsqueeze(0).to(get_device())

        mm_utils = importlib.import_module('llava.mm_utils')
        input_ids = mm_utils.tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX,
            return_tensors='pt').unsqueeze(0).to(get_device())

        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer,
                                                     input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images.half(),
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                stopping_criteria=[stopping_criteria],
            )

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids !=
                               output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(
                f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids'  # noqa
            )
        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
                                              skip_special_tokens=True)[0]

        output_text = self.post_processor(outputs, stop_str)

        if self.is_caption_task:
            data_sample.pred_caption = output_text
        else:
            data_sample.pred_answer = output_text
        return data_sample

    def forward(self, batch):
        return self.generate(batch)