HODACHI committed on
Commit
7742501
1 Parent(s): 0117a0a

First Commit

Browse files
README.md CHANGED
@@ -1,12 +1,12 @@
1
- ---
2
- title: V1
3
- emoji: 💬
4
- colorFrom: yellow
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 4.36.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
 
1
+ ---
2
+ title: Llama-3-EvoVLM-JP-v2
3
+ emoji: 🐠
4
+ colorFrom: red
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 4.36.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,63 +1,183 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for ``message`` given the earlier ``history``.

    Args:
        message: The newest user message.
        history: Earlier ``(user, assistant)`` turns; falsy parts are skipped.
        system_message: Text sent as the system message.
        max_tokens: Forwarded to ``client.chat_completion`` as ``max_tokens``.
        temperature: Forwarded to ``client.chat_completion`` as ``temperature``.
        top_p: Forwarded to ``client.chat_completion`` as ``top_p``.

    Yields:
        The response text accumulated so far, once per streamed chunk.
    """
    messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    # The loop variable is named ``chunk`` so it no longer shadows the
    # ``message`` parameter.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # A streamed chunk can arrive without choices or with a ``None``
        # delta; appending ``None`` to a str would raise TypeError.
        choices = chunk.choices
        token = choices[0].delta.content if choices else None
        if token:
            response += token
        yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
# Chat UI built around ``respond``. The additional inputs line up, in order,
# with respond's ``system_message``, ``max_tokens``, ``temperature`` and
# ``top_p`` parameters.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


# Launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import subprocess
import time
from threading import Thread

import gradio as gr
from gradio.utils import async_lambda
import spaces
import torch
from transformers import TextIteratorStreamer
from transformers.utils import is_flash_attn_2_available

from models.conversation import Conversation, SeparatorStyle
from models.mllava import (
    MLlavaProcessor,
    LlavaForConditionalGeneration,
    prepare_inputs,
)
18
+
19
# Use the GPU when torch can see one; on CPU the model is not loaded at all.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Placeholder that marks where an uploaded image belongs inside the text.
IMAGE_TOKEN = "<image>"
# Decoding options forwarded to ``prepare_inputs`` by ``bot``.
generation_kwargs = {
    "max_new_tokens": 128,
    "num_beams": 1,
    "do_sample": False,
    "no_repeat_ngram_size": 3,
}
if device == "cpu":
    processor = None
    model = None
else:
    if not is_flash_attn_2_available():
        # ``env`` replaces the child's whole environment, so extend the
        # current one instead of passing the flag alone; otherwise PATH,
        # HOME, proxy settings etc. are dropped for the pip process.
        subprocess.run(
            "pip install flash-attn --no-build-isolation",
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
            shell=True,
        )

    processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
    # Reuse the end-of-sequence token as the padding token.
    processor.tokenizer.pad_token = processor.tokenizer.eos_token

    model = LlavaForConditionalGeneration.from_pretrained(
        "HODACHI/Llama-3-EZO-VLM-1",
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        device_map=device,
    ).eval()

    # Set the system prompt
    conv_template = Conversation(
        system="<|start_header_id|>system<|end_header_id|>\n\nあなたは誠実で優秀な日本人のアシスタントです。特に指示が無い場合は、常に日本語で回答してください。",
        roles=("user", "assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA_3,
        sep="<|eot_id|>",
    )
53
+
54
+
55
def get_chat_messages(history):
    """Convert chat ``history`` into a flat list of role/text dicts.

    Only entries whose first element is a string are used; other entries
    (for example image tuples) are skipped. Each used entry contributes a
    user message followed by an assistant message. The assistant text of
    the final history entry must be empty and is emitted as ``""``; every
    earlier text entry must already have an assistant reply.

    Args:
        history: Sequence of ``[user_part, assistant_part]`` entries.

    Returns:
        A list of ``{"role": ..., "text": ...}`` dicts using the roles of
        ``conv_template``.
    """
    user_role, assistant_role = conv_template.roles[0], conv_template.roles[1]
    last_index = len(history) - 1
    turns = []
    for index, entry in enumerate(history):
        user_part = entry[0]
        if not isinstance(user_part, str):
            continue
        turns.append({"role": user_role, "text": user_part})
        reply = entry[1]
        if index == last_index:
            # The newest turn is still waiting for the model's answer.
            assert not reply, "the bot message internal error, get: {}".format(reply)
            turns.append({"role": assistant_role, "text": ""})
        else:
            assert reply, "The bot message is not provided, internal error"
            turns.append({"role": assistant_role, "text": reply})
    return turns
71
+
72
+
73
def get_chat_images(history):
    """Return all images found in ``history``, in order of appearance.

    An entry carries images when its first element is a tuple; every item
    of that tuple is collected. Entries with any other first element are
    ignored.
    """
    return [
        image
        for entry in history
        if isinstance(entry[0], tuple)
        for image in entry[0]
    ]
79
+
80
+
81
@spaces.GPU
def bot(message, history):
    """Stream the model's reply to one multimodal chat message.

    Args:
        message: Mapping with a ``"text"`` string and a ``"files"`` list of
            uploaded images.
        history: Earlier chat entries; the new entries are appended to a
            fresh list, ``history`` itself is not modified.

    Yields:
        The reply text accumulated so far, once per streamed piece.

    Raises:
        gr.Error: If the model is not loaded, the text is empty, or the
            text contains more ``<image>`` placeholders than uploaded
            images.
    """
    if not model:
        print(message, history)
        # ``model`` and ``processor`` are None when no GPU was detected at
        # start-up; fail with a readable message instead of an opaque
        # error from the generation code further down.
        raise gr.Error("The model is not loaded because no GPU is available.")
    images = message["files"] if message["files"] else None
    text = message["text"].strip()
    if not text:
        raise gr.Error("You must enter a message!")
    num_image_tokens = text.count(IMAGE_TOKEN)
    # modify text
    if images and num_image_tokens < len(images):
        if num_image_tokens != 0:
            gr.Warning(
                "The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text."
            )
        # prefix image tokens
        text = IMAGE_TOKEN * (len(images) - num_image_tokens) + text
    # NOTE(review): placeholders typed without any upload are not rejected
    # here because of the ``images and`` guard - confirm that is intended.
    if images and num_image_tokens > len(images):
        raise gr.Error(
            "The number of images uploaded is less than the number of <image> placeholders in the text!"
        )

    # Each image becomes its own history entry, followed by the text entry.
    current_messages = []
    if images:
        current_messages += [[(image,), None] for image in images]
    if text:
        current_messages += [[text, None]]
    current_history = history + current_messages
    chat_messages = get_chat_messages(current_history)
    chat_images = get_chat_images(current_history)

    # Generate!
    inputs = prepare_inputs(None, chat_images, model, processor, history=chat_messages, **generation_kwargs)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    inputs["streamer"] = streamer
    # ``generate`` runs in a worker thread while this generator drains the
    # streamer and forwards the partial text.
    thread = Thread(target=model.generate, kwargs=inputs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer
    # The streamer is exhausted, so generation is finishing; reap the thread.
    thread.join()
123
+
124
+
125
# Markdown header of the page; rendered through gr.Markdown(DESCRIPTION) below.
DESCRIPTION = """# 🐟 Llama-3-EZO-VLM-1
🤗 [モデル一覧](https://huggingface.co/HODACHI) |


[Llama-3-EZO-VLM-1](https://huggingface.co/HODACHI/Llama-3-EZO-VLM-1)は[Axcxept co., ltd.](https://axcxept.com/)が
[SakanaAI/Llama-3-EvoVLM-JP-v2](https://huggingface.co/SakanaAI/Llama-3-EvoVLM-JP-v2)をベースに性能向上を行った視覚言語モデルです。
"""

# Example inputs for the chat interface: a text prompt plus two image files each.
examples = [
    {
        "text": "1番目と2番目の画像に写っている動物の違いは何ですか?簡潔に説明してください。",
        "files": ["./examples/image_0.jpg", "./examples/image_1.jpg"],
    },
    {
        "text": "2枚の写真について、簡単にそれぞれ説明してください。",
        "files": ["./examples/image_2.jpg", "./examples/image_3.jpg"],
    },
]


# Multimodal chat interface whose handler is ``bot``; the textbox accepts
# image files only and no stop button is configured.
chat = gr.ChatInterface(
    fn=bot,
    multimodal=True,
    chatbot=gr.Chatbot(label="Chatbot", scale=1, height=500),
    textbox=gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        # file_count="multiple",
        placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images",
        show_label=True,
    ),
    examples=examples,
    fill_height=False,
    stop_btn=None,
)
# Page layout: description, the chat interface, then usage notes and disclaimer.
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    chat.render()
    # Chained after the examples handler's load_input_event: resets the
    # chatbot, the chatbot state and the saved input to [], [] and None.
    chat.examples_handler.load_input_event.then(
        fn=async_lambda(lambda: [[], [], None]),
        outputs=[chat.chatbot, chat.chatbot_state, chat.saved_input],
    )

    # Usage instructions and disclaimer shown below the chat (Japanese UI text).
    gr.Markdown(
        """
### チャットの方法
HODACHI/Llama-3-EZO-VLM-1は、画像をテキストの好きな場所に入力として配置することができます。画像をアップロードする場所は、`<image>`というフレーズで指定できます。
モデルの推論時に、自動的に`<image>`が画像トークンに置き換えられます。また、画像のアップロード数が`<image>`の数よりも少ない場合、余分な`<image>`が削除されます。
逆に、画像のアップロード数が`<image>`の数よりも多い場合、自動的に`<image>`が追加されます。

### 注意事項
本モデルは実験段階のプロトタイプであり、研究開発の目的でのみ提供されています。商用利用や、障害が重大な影響を及ぼす可能性のある環境(ミッションクリティカルな環境)での使用には適していません。
本モデルの使用は、利用者の自己責任で行われ、その性能や結果については何ら保証されません。
Axcxept co., ltd.は、本モデルの使用によって生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。
利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。
また、このデモでは、できる限り多くの皆様にお使いいただけるように、出力テキストのサイズを制限しております。"""
    )

# Enable request queuing and start the app.
demo.queue().launch()
examples/image_0.jpg ADDED
examples/image_1.jpg ADDED
examples/image_2.jpg ADDED
examples/image_3.jpg ADDED
models/conversation.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script includes codes copied directly from https://huggingface.co/spaces/TIGER-Lab/Mantis
2
+
3
+ import dataclasses
4
+ from enum import auto, Enum
5
+ from typing import List, Tuple
6
+
7
+
8
class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    LLAMA_3 = auto()
    MFuyu = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    ``messages`` holds ``[role, message]`` pairs. A message is either a
    string, a falsy value (a turn that still has to be generated), or a
    ``(text, image, image_process_mode)`` tuple.
    """
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    @staticmethod
    def _message_text(message):
        """Return the text of ``message``, unpacking ``(text, image, mode)`` tuples."""
        if isinstance(message, tuple):
            message, _, _ = message
        return message

    def get_prompt(self):
        """Render the conversation as one prompt string for ``sep_style``.

        Returns:
            The prompt text; a falsy last message leaves the prompt open for
            the model to continue.

        Raises:
            ValueError: If ``sep_style`` is not a supported style.
        """
        messages = self.messages
        if len(messages) > 0 and isinstance(messages[0][1], tuple):
            # Normalise the first (image-carrying) message on a shallow copy
            # so the stored history is left as it is.
            messages = list(self.messages)
            init_role, init_msg = messages[0]
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    ret += role + ": " + self._message_text(message) + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    ret += role + ": " + self._message_text(message) + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    ret += role + self._message_text(message) + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    message = self._message_text(message)
                    if i == 0:
                        # The system prompt is folded into the first user turn.
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.LLAMA_3:
            ret = self.system + self.sep
            for role, message in messages:
                header = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
                if message:
                    ret += header + self._message_text(message) + self.sep
                else:
                    ret += header
        elif self.sep_style == SeparatorStyle.MFuyu:
            seps = [self.sep, self.sep2]
            ret = self.system + "\n"
            for i, (role, message) in enumerate(messages):
                if message:
                    ret += role + ": " + self._message_text(message) + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    ret += self._message_text(message) + seps[i % 2]
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect the images attached to the messages at even positions after ``offset``.

        Args:
            return_pil: When true the processed image objects are returned;
                otherwise each image is PNG-encoded and returned as a base64
                string.

        Returns:
            A list with one item per image-carrying message.

        Raises:
            ValueError: If a message names an unknown ``image_process_mode``.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if isinstance(msg, tuple):
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            # Centre the image on a square canvas of its longer side.
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    # Bound the size while keeping the aspect ratio.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        """Return the history after ``offset`` as ``[user, assistant]`` pairs.

        Image-carrying user messages are rendered as an inline ``<img>`` tag
        (JPEG data URI) followed by the text with ``<image>`` removed.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if isinstance(msg, tuple):
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    # JPEG cannot store alpha/palette modes, so convert first.
                    if image.mode != "RGB":
                        image = image.convert("RGB")
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    # The MIME type must match the JPEG encoding used above.
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a new Conversation with its own ``messages`` list.

        ``skip_next`` is not carried over; the copy gets the field default.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """Return system, roles, messages, offset, sep and sep2 as a dict.

        When the conversation contains images, tuple messages are reduced to
        their text part.
        """
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if isinstance(y, tuple) else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
242
+
243
+
244
# Vicuna v0 template: SINGLE style with "###"; ships with one example exchange
# that sits before ``offset``.
conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# Vicuna v1 template: TWO style, " " after user turns and "</s>" after assistant turns.
conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# Llama 2 chat template (LLAMA_2 style) with a general-purpose system prompt.
conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

# Same LLAMA_2 formatting with a vision-assistant system prompt.
conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

# MPT template: the role strings already contain the <|im_start|> markers.
conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# Plain template: empty system prompt and empty role names.
conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

# LLaVA v0 template: SINGLE style with "###" and no seeded messages.
conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# LLaVA v0 variant whose version contains "mmtag"; get_prompt checks for that substring.
conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

# LLaVA v1 template: TWO style with " " and "</s>".
conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# LLaVA v1 variant whose version contains "mmtag".
conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)

# MFuyu template: alternates the "<0x04>" and "|ENDOFTEXT|" separators.
conv_mfuyu_v1 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MFuyu,
    sep="<0x04>", # begin of answer token
    sep2="|ENDOFTEXT|",
) # copied from conv_vicuna_v1

# Multi-image LLaVA template (SINGLE style, "</s>") whose version contains "mmtag".
conv_mllava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the multiple visual contents that the user provides, and assist the user with a variety of tasks using natural language."
           "Each visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="</s>",
    version="v1_mmtag",
)

# Multi-image LLaVA template: SINGLE style with "</s>".
conv_mllava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="</s>",
)

# Llama 3 template: the header tokens for the system turn are part of ``system``.
conv_llama_3 = Conversation(
    system="<|start_header_id|>system<|end_header_id|>\n\nYou are a pirate chatbot who always responds in pirate speak!",
    roles=("user", "assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep="<|eot_id|>",
)

# NOTE(review): ``default_conversation`` is the MFuyu template while the
# "default" registry key maps to conv_vicuna_v0 - confirm this is intended.
default_conversation = conv_mfuyu_v1
# Registry of templates by name.
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llama_2": conv_llama_2,

    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,
    "llama_3": conv_llama_3,
    "mllava_v1": conv_mllava_v1,
    "mllava_v1_mmtag": conv_mllava_v1_mmtag,

    "mpt": conv_mpt,
}


# When run as a script, print the prompt of the default template.
if __name__ == "__main__":
    print(default_conversation.get_prompt())
models/mllava/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # This script includes codes copied directly from https://huggingface.co/spaces/TIGER-Lab/Mantis
2
+
3
+ from .modeling_llava import LlavaForConditionalGeneration, MLlavaForConditionalGeneration
4
+ from .processing_llava import MLlavaProcessor
5
+ from .configuration_llava import LlavaConfig
6
+ from .utils import chat_mllava, prepare_inputs
models/mllava/configuration_llava.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This script includes codes copied directly from https://huggingface.co/spaces/TIGER-Lab/Mantis
16
+
17
+ """ Llava model configuration"""
18
+
19
+
20
+ # from ...configuration_utils import PretrainedConfig
21
+ # from ...utils import logging
22
+ # from ..auto import CONFIG_MAPPING
23
+ from transformers.configuration_utils import PretrainedConfig
24
+ from transformers.utils import logging
25
+ from transformers.models.auto import CONFIG_MAPPING
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+ LLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
31
+ "llava-hf/llava-v1.5-7b": "https://huggingface.co/llava-hf/llava-v1.5-7b/resolve/main/config.json",
32
+ }
33
+
34
+
35
class LlavaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`LlavaForConditionalGeneration`]. It is used to instantiate an
    Llava model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Llava-9B.

    e.g. [llava-hf/llava-9b](https://huggingface.co/llava-hf/llava-9b)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`LlavaVisionConfig`,  *optional*):
            Custom vision config or dict
        text_config (`Union[AutoConfig, dict]`, *optional*):
            The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        image_token_index (`int`, *optional*, defaults to 32000):
            The image token index to encode the image prompt.
        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The activation function used by the multimodal projector.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            The feature selection strategy used to select the vision feature from the CLIP backbone.
        vision_feature_layer (`int`, *optional*, defaults to -2):
            The index of the layer to select the vision feature.
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Llava model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`~LlavaForConditionalGeneration`]

    Example:

    ```python
    >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig

    >>> # Initializing a CLIP-vision config
    >>> vision_config = CLIPVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a Llava llava-1.5-7b style configuration
    >>> configuration = LlavaConfig(vision_config, text_config)

    >>> # Initializing a model from the llava-1.5-7b style configuration
    >>> model = LlavaForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "llava"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        vocab_size=32000,
        **kwargs,
    ):
        """Store the Llava settings and resolve the vision and text sub-configs.

        ``vision_config`` and ``text_config`` may each be a dict (turned into
        the config class registered for its ``"model_type"``), ``None`` (a
        default CLIP vision / Llama config is created), or any other object,
        which is stored unchanged. When ``text_config`` is a dict,
        ``vocab_size`` is taken from the resulting text config.
        """
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer
        self.vocab_size = vocab_size

        self.vision_config = vision_config

        if isinstance(vision_config, dict):
            # Fill in the default model type on a copy so the caller's dict
            # is not modified.
            vision_config = {"model_type": "clip_vision_model", **vision_config}
            self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            self.vision_config = CONFIG_MAPPING["clip_vision_model"](
                intermediate_size=4096,
                hidden_size=1024,
                patch_size=14,
                image_size=336,
                num_hidden_layers=24,
                num_attention_heads=16,
                vocab_size=32000,
                projection_dim=768,
            )

        self.text_config = text_config

        if isinstance(text_config, dict):
            # Same copy-then-default handling as for the vision config.
            text_config = {"model_type": "llama", **text_config}
            self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
            self.vocab_size = self.text_config.vocab_size
        elif text_config is None:
            self.text_config = CONFIG_MAPPING["llama"]()

        super().__init__(**kwargs)
models/mllava/modeling_llava.py ADDED
@@ -0,0 +1,773 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # This script includes code copied directly from https://huggingface.co/spaces/TIGER-Lab/Mantis
17
+
18
+ """ PyTorch Llava model."""
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+
26
+ # from ... import PreTrainedModel
27
+ # from ...activations import ACT2FN
28
+ # from ...cache_utils import Cache
29
+ # from ...modeling_outputs import ModelOutput
30
+ # from ...utils import (
31
+ # add_start_docstrings,
32
+ # add_start_docstrings_to_model_forward,
33
+ # logging,
34
+ # replace_return_docstrings,
35
+ # )
36
+ # from ..auto import AutoModel, AutoModelForCausalLM
37
+
38
+ from .configuration_llava import LlavaConfig
39
+
40
+ from transformers import PreTrainedModel
41
+ from transformers.activations import ACT2FN
42
+ from transformers.cache_utils import Cache
43
+ from transformers.modeling_outputs import ModelOutput
44
+ from transformers.utils import (
45
+ add_start_docstrings,
46
+ add_start_docstrings_to_model_forward,
47
+ logging,
48
+ replace_return_docstrings,
49
+ )
50
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM
51
+ from .configuration_llava import LlavaConfig
52
+
53
+
54
# Module-level logger obtained through the `transformers` logging helper.
logger = logging.get_logger(__name__)

# Configuration class name as a string; handed to `replace_return_docstrings` on `forward`.
_CONFIG_FOR_DOC = "LlavaConfig"

# Hub identifiers of Llava checkpoints.
LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "llava-hf/llava-1.5-7b-hf",
    "llava-hf/llava-1.5-13b-hf",
    "llava-hf/bakLlava-v1-hf",
    # See all Llava models at https://huggingface.co/models?filter=llava
]
64
+
65
+
66
@dataclass
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
# NOTE(review): this class is passed as `output_type` to `replace_return_docstrings` on `forward`;
# presumably the `Args:` section layout below is consumed there, so keep that format intact.
class LlavaCausalLMOutputWithPast(ModelOutput):
    """
    Base class for Llava causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
            num_images, sequence_length, hidden_size)`.

            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Not populated by `LlavaForConditionalGeneration.forward` in this file; stays at its default there.
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
107
+
108
+
109
class LlavaMultiModalProjector(nn.Module):
    """Linear -> activation -> linear projection from the vision hidden size to the text hidden size."""

    def __init__(self, config: LlavaConfig):
        super().__init__()
        vision_width = config.vision_config.hidden_size
        text_width = config.text_config.hidden_size

        # Sub-modules are created in the same order as before so parameter names and
        # initialisation order are unchanged.
        self.linear_1 = nn.Linear(vision_width, text_width, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(text_width, text_width, bias=True)

    def forward(self, image_features):
        """Return `image_features` passed through `linear_1`, the activation and `linear_2`."""
        projected = self.linear_1(image_features)
        return self.linear_2(self.act(projected))
122
+
123
+
124
+ LLAVA_START_DOCSTRING = r"""
125
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
126
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
127
+ etc.)
128
+
129
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
130
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
131
+ and behavior.
132
+
133
+ Parameters:
134
+ config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
135
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
136
+ load the weights associated with the model, only the configuration. Check out the
137
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
138
+ """
139
+
140
+
141
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAVA_START_DOCSTRING,
)
class LlavaPreTrainedModel(PreTrainedModel):
    # Common base for the Llava models in this file: declares the configuration class and
    # the capability flags read by the `transformers` loading machinery.
    config_class = LlavaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlavaVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialise `module` in place with a zero-mean normal distribution.

        This port targets inference and fine-tuning only, not training from scratch; the
        original initialisation code lives in https://github.com/haotian-liu/LLaVA/tree/main/llava.
        """
        config = self.config
        # Prefer a top-level `initializer_range`, otherwise use the text model's value.
        if hasattr(config, "initializer_range"):
            std = config.initializer_range
        else:
            std = config.text_config.initializer_range

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # The padding row, when defined, is zeroed.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        module.weight.data.normal_(mean=0.0, std=std)
        if module.bias is not None:
            module.bias.data.zero_()

    @property
    def _supports_sdpa(self):
        """
        Report SDPA support by forwarding the flag of the wrapped `language_model`.
        """
        return self.language_model._supports_sdpa
182
+
183
+
184
+ LLAVA_INPUTS_DOCSTRING = r"""
185
+ Args:
186
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
187
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
188
+ it.
189
+
190
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
191
+ [`PreTrainedTokenizer.__call__`] for details.
192
+
193
+ [What are input IDs?](../glossary#input-ids)
194
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
195
+ The tensors corresponding to the input images. Pixel values can be obtained using
196
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([`LlavaProcessor`] uses
197
+ [`CLIPImageProcessor`] for processing images).
198
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
199
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
200
+
201
+ - 1 for tokens that are **not masked**,
202
+ - 0 for tokens that are **masked**.
203
+
204
+ [What are attention masks?](../glossary#attention-mask)
205
+
206
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
207
+ [`PreTrainedTokenizer.__call__`] for details.
208
+
209
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
210
+ `past_key_values`).
211
+
212
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
213
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
214
+ information on the default strategy.
215
+
216
+ - 1 indicates the head is **not masked**,
217
+ - 0 indicates the head is **masked**.
218
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
219
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
220
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
221
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
222
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
223
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
224
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
225
+
226
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
227
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
228
+
229
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
230
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
231
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
232
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
233
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
234
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
235
+ model's internal embedding lookup matrix.
236
+ use_cache (`bool`, *optional*):
237
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
238
+ `past_key_values`).
239
+ output_attentions (`bool`, *optional*):
240
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
241
+ tensors for more detail.
242
+ output_hidden_states (`bool`, *optional*):
243
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
244
+ more detail.
245
+ return_dict (`bool`, *optional*):
246
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
247
+ """
248
+
249
+
250
@add_start_docstrings(
    """The LLAVA model which consists of a vision backbone and a language model.""",
    LLAVA_START_DOCSTRING,
)
class LlavaForConditionalGeneration(LlavaPreTrainedModel):
    # Wraps three sub-modules: `vision_tower` (image encoder), `multi_modal_projector`
    # (maps image features to the text hidden size) and `language_model` (causal LM).
    # Embedding / decoder accessors simply delegate to `language_model`.

    def __init__(self, config: LlavaConfig, vision_tower=None, language_model=None):
        """Create the model, building any sub-module that is not supplied.

        Args:
            config: Configuration providing `vision_config`, `text_config` and `vocab_size`.
            vision_tower: Optional pre-built vision encoder; built from
                `config.vision_config` when `None`.
            language_model: Optional pre-built causal language model; built from
                `config.text_config` when `None`.
        """
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config) if vision_tower is None else vision_tower

        self.multi_modal_projector = LlavaMultiModalProjector(config)
        self.vocab_size = config.vocab_size
        self.language_model = (
            AutoModelForCausalLM.from_config(config.text_config, attn_implementation=config._attn_implementation)
            if language_model is None
            else language_model
        )
        # With no pad token configured, -1 is used; the padding-side check in
        # `_merge_input_ids_with_image_features` then never finds a pad token in the last column.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embedding module of the wrapped language model."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Set the input embedding module of the wrapped language model."""
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the output embedding module of the wrapped language model."""
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Set the output embedding module of the wrapped language model."""
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        """Set the decoder of the wrapped language model."""
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        """Return the decoder of the wrapped language model."""
        return self.language_model.get_decoder()

    def tie_weights(self):
        """Delegate weight tying to the wrapped language model."""
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        """Resize the language model's token embeddings and keep every stored vocab size in sync.

        Returns:
            The embedding module returned by the language model's `resize_token_embeddings`.
        """
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # The vocab size is recorded in three places; update them together.
        new_vocab_size = model_embeds.num_embeddings
        self.config.text_config.vocab_size = new_vocab_size
        self.config.vocab_size = new_vocab_size
        self.vocab_size = new_vocab_size
        return model_embeds

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
        """Expand every image placeholder token into its image patch embeddings.

        Args:
            image_features: Tensor unpacked as `(num_images, num_image_patches, embed_dim)`.
            inputs_embeds: Text embeddings indexed as `[batch, position]`.
            input_ids: Tensor unpacked as `(batch_size, sequence_length)`.
            attention_mask: Mask indexed like `input_ids`.
            labels: Optional labels indexed like `input_ids`.

        Returns:
            A tuple `(final_embedding, final_attention_mask, final_labels, position_ids)` over the
            expanded sequence; `final_labels` is `None` when `labels` is `None`.

        Raises:
            ValueError: If the number of free slots left for images does not match the number of
                image feature vectors supplied.
        """
        num_images, num_image_patches, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        # Left padding is assumed when no sequence ends with the pad token.
        left_padding = not torch.sum(input_ids[:, -1] == self.pad_token_id)
        # 1. Create a mask to know where special image tokens are
        special_image_token_mask = input_ids == self.config.image_token_index
        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
        batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged image-text sequence.
        # `special_image_token_mask` identifies image tokens. Each image token adds
        # `num_image_patches - 1` extra positions to the sequence.
        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
        final_labels = None
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

        return final_embedding, final_attention_mask, final_labels, position_ids

    @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

        >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
        ```"""

        # Per-call arguments take precedence; the configuration supplies the defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if inputs_embeds is None:
            # 1. Extract the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # 2. Merge text and images
            if pixel_values is not None and input_ids.shape[1] != 1:
                if isinstance(pixel_values, list):
                    # A list may contain `None` entries; the remaining tensors are concatenated on dim 0.
                    pixel_values = torch.cat([x for x in pixel_values if x is not None], dim=0)
                # for siglip, need to transform the pixel_values to the right data type
                if pixel_values.dtype != self.vision_tower.dtype:
                    pixel_values = pixel_values.type(self.vision_tower.dtype)
                image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
                # `output_hidden_states=True` keeps every layer's hidden state, which is not memory efficient.
                selected_image_feature = image_outputs.hidden_states[vision_feature_layer]

                if vision_feature_select_strategy == "default":
                    # Drop the first position along dim 1 (presumably the class token - confirm per vision tower).
                    selected_image_feature = selected_image_feature[:, 1:]
                elif vision_feature_select_strategy != "full":
                    # Report the strategy actually in effect, which may be a per-call override.
                    raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")

                image_features = self.multi_modal_projector(selected_image_feature)
                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                    image_features, inputs_embeds, input_ids, attention_mask, labels
                )
                if labels is None:
                    # NOTE(review): this makes `labels` non-None, so the loss branch below also runs when the
                    # caller passed no labels (every target is `ignore_index`). Kept for output compatibility.
                    labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
            else:
                # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
                # generation with cache
                if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                    # Get the target length
                    target_seqlen = first_layer_past_key_value.shape[-1] + 1

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Filter out only the tokens that can be un-attended, this can happen
                    # if one uses Llava + Fused modules where the cache on the
                    # first iteration is already big enough, or if one passes custom cache
                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                    new_batch_index = batch_index[valid_indices]
                    new_non_attended_tokens = non_attended_tokens[valid_indices]

                    # Zero-out the places where we don't need to attend
                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                    attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # Only positions whose target token is attended contribute to the loss.
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        Trims `input_ids` (and, for size-limited caches, `attention_mask`) to the part not yet
        covered by `past_key_values`, and derives `position_ids` from the attention mask when
        the caller supplied none.

        Returns:
            A dict with `input_ids` or `inputs_embeds`, plus `position_ids`, `past_key_values`,
            `use_cache`, `attention_mask` and `pixel_values`.
        """
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            elif self.config.image_token_index in input_ids:
                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
            # older attention values, as their corresponding values are not part of the input.
            if cache_length < past_length and attention_mask is not None:
                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
            }
        )
        return model_inputs

    def _reorder_cache(self, *args, **kwargs):
        """Delegate cache reordering to the wrapped language model."""
        return self.language_model._reorder_cache(*args, **kwargs)
587
+
588
+
589
+
590
+
591
+ from transformers.models.clip.modeling_clip import CLIPEncoderLayer, CLIPEncoder
592
+ @add_start_docstrings(
593
+ """The MLLAVA model which consists of a vision backbone and a language model.""",
594
+ LLAVA_START_DOCSTRING,
595
+ )
596
+ class MLlavaForConditionalGeneration(LlavaForConditionalGeneration):
597
def __init__(self, config: LlavaConfig):
    """Build the base Llava model, then add image-type embeddings and an extra CLIP encoder stack."""
    super().__init__(config)
    vision_cfg = config.vision_config
    # The image-type vocabulary size is fixed here and written back onto the vision config.
    vision_cfg.type_vocab_size = 144
    self.image_type_embeddings = nn.Embedding(vision_cfg.type_vocab_size, vision_cfg.hidden_size)
    # A full `CLIPEncoder` built from the vision configuration.
    self.vision_xatten_layers = CLIPEncoder(vision_cfg)
603
+
604
+
605
+ @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
606
+ @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
607
+ def forward(
608
+ self,
609
+ input_ids: torch.LongTensor = None,
610
+ pixel_values: torch.FloatTensor = None,
611
+ attention_mask: Optional[torch.Tensor] = None,
612
+ position_ids: Optional[torch.LongTensor] = None,
613
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
614
+ inputs_embeds: Optional[torch.FloatTensor] = None,
615
+ vision_feature_layer: Optional[int] = None,
616
+ vision_feature_select_strategy: Optional[str] = None,
617
+ labels: Optional[torch.LongTensor] = None,
618
+ use_cache: Optional[bool] = None,
619
+ output_attentions: Optional[bool] = None,
620
+ output_hidden_states: Optional[bool] = None,
621
+ return_dict: Optional[bool] = None,
622
+ ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
623
+ r"""
624
+ Args:
625
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
626
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
627
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
628
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
629
+
630
+ Returns:
631
+
632
+ Example:
633
+
634
+ ```python
635
+ >>> from PIL import Image
636
+ >>> import requests
637
+ >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
638
+
639
+ >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
640
+ >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
641
+
642
+ >>> prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
643
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
644
+ >>> image = Image.open(requests.get(url, stream=True).raw)
645
+
646
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
647
+
648
+ >>> # Generate
649
+ >>> generate_ids = model.generate(**inputs, max_length=30)
650
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
651
+ "\nUSER: What's the content of the image?\nASSISTANT: The image features a stop sign on a street corner"
652
+ ```"""
653
+
654
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
655
+ output_hidden_states = (
656
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
657
+ )
658
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
659
+ vision_feature_layer = (
660
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
661
+ )
662
+ vision_feature_select_strategy = (
663
+ vision_feature_select_strategy
664
+ if vision_feature_select_strategy is not None
665
+ else self.config.vision_feature_select_strategy
666
+ )
667
+
668
+ if inputs_embeds is None:
669
+ # 1. Extra the input embeddings
670
+ inputs_embeds = self.get_input_embeddings()(input_ids)
671
+
672
+ # 2. Merge text and images
673
+ if pixel_values is not None and input_ids.shape[1] != 1:
674
+ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
675
+ # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
676
+ selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
677
+
678
+ if vision_feature_select_strategy == "default":
679
+ selected_image_feature = selected_image_feature[:, 1:]
680
+ elif vision_feature_select_strategy == "full":
681
+ selected_image_feature = selected_image_feature
682
+ else:
683
+ raise ValueError(
684
+ f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
685
+ )
686
+
687
+ # added by Dongfu
688
+ num_images, num_image_patches, embed_dim = selected_image_feature.shape
689
+ image_type_embeddings = self.image_type_embeddings(torch.arange(num_images, device=selected_image_feature.device))
690
+ selected_image_feature += image_type_embeddings.unsqueeze(1)
691
+ xatten_output = self.vision_xatten_layers(selected_image_feature, attention_mask=None, causal_attention_mask=None)
692
+ selected_image_feature = xatten_output[0]
693
+ # end of added by Dongfu
694
+
695
+ image_features = self.multi_modal_projector(selected_image_feature)
696
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
697
+ image_features, inputs_embeds, input_ids, attention_mask, labels
698
+ )
699
+ if labels is None:
700
+ labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
701
+ else:
702
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
703
+ # generation with cache
704
+ if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
705
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
706
+ # that are set to 0
707
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
708
+
709
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
710
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
711
+
712
+ # Get the target length
713
+ target_seqlen = first_layer_past_key_value.shape[-1] + 1
714
+
715
+ extended_attention_mask = torch.ones(
716
+ (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
717
+ dtype=attention_mask.dtype,
718
+ device=attention_mask.device,
719
+ )
720
+
721
+ # Filter out only the tokens that can be un-attended, this can happen
722
+ # if one uses Llava + Fused modules where the cache on the
723
+ # first iteration is already big enough, or if one passes custom cache
724
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
725
+ new_batch_index = batch_index[valid_indices]
726
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
727
+
728
+ # Zero-out the places where we don't need to attend
729
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
730
+
731
+ attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
732
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
733
+
734
+ outputs = self.language_model(
735
+ attention_mask=attention_mask,
736
+ position_ids=position_ids,
737
+ past_key_values=past_key_values,
738
+ inputs_embeds=inputs_embeds,
739
+ use_cache=use_cache,
740
+ output_attentions=output_attentions,
741
+ output_hidden_states=output_hidden_states,
742
+ return_dict=return_dict,
743
+ )
744
+
745
+ logits = outputs[0]
746
+
747
+ loss = None
748
+ if labels is not None:
749
+ # Shift so that tokens < n predict n
750
+ if attention_mask is not None:
751
+ shift_attention_mask = attention_mask[..., 1:]
752
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
753
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
754
+ else:
755
+ shift_logits = logits[..., :-1, :].contiguous()
756
+ shift_labels = labels[..., 1:].contiguous()
757
+ # Flatten the tokens
758
+ loss_fct = nn.CrossEntropyLoss()
759
+ loss = loss_fct(
760
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
761
+ )
762
+
763
+ if not return_dict:
764
+ output = (logits,) + outputs[1:]
765
+ return (loss,) + output if loss is not None else output
766
+
767
+ return LlavaCausalLMOutputWithPast(
768
+ loss=loss,
769
+ logits=logits,
770
+ past_key_values=outputs.past_key_values,
771
+ hidden_states=outputs.hidden_states,
772
+ attentions=outputs.attentions,
773
+ )
models/mllava/processing_llava.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # This script includes code copied directly from https://huggingface.co/spaces/TIGER-Lab/Mantis
17
+
18
+ """
19
+ Processor class for Llava.
20
+ """
21
+
22
+ import os
23
+ import json
24
+ from typing import List, Optional, Union, Dict
25
+
26
+ # from ...feature_extraction_utils import BatchFeature
27
+ # from ...image_utils import ImageInput
28
+ # from ...processing_utils import ProcessorMixin
29
+ # from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
30
+ # from ...utils import TensorType
31
+
32
+ from transformers.feature_extraction_sequence_utils import BatchFeature
33
+ from transformers.image_utils import ImageInput
34
+ from transformers.processing_utils import ProcessorMixin
35
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
36
+ from transformers.utils import TensorType
37
+ from transformers.processing_utils import transformers_module
38
+ from transformers.utils.hub import is_remote_url, download_url, cached_file, is_offline_mode
39
+ from transformers.utils import IMAGE_PROCESSOR_NAME
40
+
41
+ from PIL import Image
42
+ import logging
43
+ import torch
44
+ import numpy as np
45
+ logger = logging.getLogger(__name__)
46
+
47
class MLlavaProcessor(ProcessorMixin):
    r"""
    Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.

    [`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
    [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.

    In addition, this processor can align the number of `<image>` placeholders in a prompt with the number of
    images supplied and number each placeholder (see `preprocess_interleaved_images_and_text`).

    Args:
        image_processor ([`CLIPImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = ("CLIPImageProcessor", "SiglipImageProcessor")
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast", "PreTrainedTokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None):
        super().__init__(image_processor, tokenizer)

    @staticmethod
    def _align_image_tokens(prompt, num_images):
        """Return `prompt` with exactly `num_images` `<image>` placeholders.

        Missing placeholders are inserted right after the first "USER:", "Human:" or "HUMAN:" marker
        (checked in that order), or prepended when no marker is present. Surplus placeholders are removed
        from the end of the prompt and a warning is logged.
        """
        num_image_tokens = prompt.count("<image>")
        if num_image_tokens < num_images:
            padding = "<image>" * (num_images - num_image_tokens)
            for marker in ("USER:", "Human:", "HUMAN:"):
                if marker in prompt:
                    return prompt.replace(marker, marker + padding, 1)
            return padding + prompt
        if num_image_tokens > num_images:
            pieces = prompt.split("<image>")
            # Re-attach a placeholder only after the first `num_images` pieces; later ones are dropped.
            prompt = "".join(
                piece + "<image>" if idx < num_images else piece for idx, piece in enumerate(pieces)
            )
            logger.warning(f"Number of <image> tokens: {num_image_tokens} exceeds number of images: {num_images}. Automatically removing extra tokens at the end of the text.")
        return prompt

    @staticmethod
    def _number_image_tokens(prompt, num_images):
        """Rewrite each `<image>` placeholder as "(image {i}: <Image><image></Image>)", numbered from 1."""
        for j in range(num_images):
            # The temporary upper-case <IMAGE> keeps an already rewritten placeholder from being
            # matched again by the next iteration.
            prompt = prompt.replace("<image>", f"(image {j+1}: <Image><IMAGE></Image>)", 1)
        return prompt.replace("<IMAGE>", "<image>")

    def preprocess_interleaved_images_and_text(
        self,
        text,
        images=None,
    ):
        """
        Align `<image>` placeholders in `text` with `images` and number the placeholders.

        Args:
            text (`str`, `List[str]`):
                The prompt or batch of prompts. text can contain <image> tokens as the placeholder for the
                image(s) to be inserted. The caller's list is not modified.
            images (`PIL.Image.Image`, `List[PIL.Image.Image]`, `List[List[PIL.Image.Image]]`):
                The image or batch of images. For a single prompt, a flat list holds all images of that prompt;
                for a batch of prompts, a flat list holds one image per prompt and a nested list holds the
                images of each prompt. An empty list is treated like `None`.

        Returns:
            A tuple `(texts, images)`: `texts` is a list of prompts and `images` is a flat list of all images
            (or `None` when no images were given).

        Raises:
            ValueError: if `text` is neither a string nor a list of strings, or if a flat image list is given
                for a batch of prompts of a different length.
        """
        assert text is not None, "text cannot be None."

        # An empty list carries no images; without this, `images[0]` below would raise IndexError.
        if isinstance(images, list) and not images:
            images = None

        if images is None:
            if isinstance(text, str):
                return [text], None
            if isinstance(text, list):
                if not isinstance(text[0], str):
                    raise ValueError("Invalid input text. Each element of text must be a string.")
                return list(text), None
            raise ValueError("Invalid input text. text must be a string or a list of strings.")

        # Normalise `images` to one list of images per prompt.
        if isinstance(images, Image.Image):
            images = [images]
        if isinstance(images, list) and isinstance(images[0], Image.Image):
            if isinstance(text, str):
                images = [images]
            elif isinstance(text, list):
                if len(text) != len(images):
                    raise ValueError("Invalid input text. Number of texts does not match number of images.")
                images = [[image] for image in images]

        if isinstance(text, str):
            texts = [self._align_image_tokens(text, len(images[0]))]
        elif isinstance(text, list):
            if not isinstance(text[0], str):
                raise ValueError("Invalid input text. Each element of text must be a string.")
            # Build a new list so the caller's prompts are left untouched.
            texts = [self._align_image_tokens(t, len(images[i])) for i, t in enumerate(text)]
        else:
            raise ValueError("Invalid input text. text must be a string or a list of strings.")
        assert all([t.count("<image>") == len(images_per_text) for t, images_per_text in zip(texts, images)]), "Number of <image> tokens in text does not match number of images."

        # add image denotation in text before each <image> as "(image {i}: <image>)"
        texts = [self._number_image_tokens(t, len(images[i])) for i, t in enumerate(texts)]

        # flatten images
        images = [image for images_per_text in images for image in images_per_text]
        return texts, images

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length=None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        add_image_ids: bool = True,
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards `text`
        and the padding/truncation arguments to the tokenizer, and forwards `images` to the image processor if
        `images` is not `None`. Please refer to the docstring of those two components for more information.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
                Select a strategy to pad the returned sequences (according to the model's padding side and padding
                index) among:
                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
                  sequence if provided).
                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
                  acceptable input length for the model if that argument is not provided.
                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
                  lengths).
            max_length (`int`, *optional*):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (`bool`, *optional*):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.
            add_image_ids (`bool`, *optional*, defaults to `True`):
                Whether to run `preprocess_interleaved_images_and_text` first, which aligns the `<image>`
                placeholders with the images and numbers them.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model, or `None` when `images` is `None`.
        """
        if add_image_ids:
            text, images = self.preprocess_interleaved_images_and_text(text, images)
        if images is not None:
            pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]  # [batch_size, num_channels, height, width], e.g. [1, 3, 336, 336]
        else:
            pixel_values = None
        text_inputs = self.tokenizer(
            text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
        )
        # text_inputs:
        # 1. input_ids: [batch_size, sequence_length], e.g. [1, 6]
        # 2. attention_mask: [batch_size, sequence_length], e.g. [1, 6]

        return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
    def model_input_names(self):
        """Tokenizer input names followed by image-processor input names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    def _right_pad_inputs_with_attention_mask(self, model_inputs: List[Dict]):
        """Collate a one-element list of model inputs into a single dict.

        `pixel_values` entries are gathered into a list; every other entry is concatenated along dim 0.
        Despite the name, no padding is performed here, which is why only a single input is accepted.
        """
        results = {}
        assert len(model_inputs) == 1, "This method only supports a single input, but get {} inputs".format(len(model_inputs))
        for k in model_inputs[0].keys():
            if k == "pixel_values":
                results[k] = [inputs[k] for inputs in model_inputs]
            else:
                results[k] = torch.cat([inputs[k] for inputs in model_inputs], dim=0)
        return results

    @classmethod
    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Instantiate the image processor and tokenizer for `pretrained_model_name_or_path`.

        The image-processor config file is read first so that the class named by its `image_processor_type`
        entry (which must be one of `image_processor_class`) can be chosen; without that entry the first
        class of the tuple is used. For the tokenizer, the second entry of `tokenizer_class` is used when
        `use_fast` is true (the default), otherwise the first.

        Returns:
            A list with one loaded object per name in `cls.attributes`, in that order.

        Raises:
            EnvironmentError: if the image-processor config cannot be located or is not valid JSON.
        """
        args = []

        # Read (do not pop) the hub options: they are needed here to locate the config file and must
        # also stay in `kwargs` so the tokenizer / image processor loaders below receive them.
        cache_dir = kwargs.get("cache_dir", None)
        force_download = kwargs.get("force_download", False)
        resume_download = kwargs.get("resume_download", False)
        proxies = kwargs.get("proxies", None)
        token = kwargs.get("token", None)
        local_files_only = kwargs.get("local_files_only", False)
        revision = kwargs.get("revision", None)
        subfolder = kwargs.get("subfolder", "")

        from_pipeline = kwargs.get("_from_pipeline", None)
        from_auto_class = kwargs.get("_from_auto", False)

        user_agent = {"file_type": "processor", "from_auto_class": from_auto_class}
        if from_pipeline is not None:
            user_agent["using_pipeline"] = from_pipeline

        if is_offline_mode() and not local_files_only:
            logger.info("Offline mode: forcing local_files_only=True")
            local_files_only = True

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        if os.path.isfile(pretrained_model_name_or_path):
            resolved_processor_file = pretrained_model_name_or_path
        elif is_remote_url(pretrained_model_name_or_path):
            resolved_processor_file = download_url(pretrained_model_name_or_path)
        else:
            try:
                # Load from local folder or from cache or download from model Hub and cache
                resolved_processor_file = cached_file(
                    pretrained_model_name_or_path,
                    IMAGE_PROCESSOR_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    revision=revision,
                    subfolder=subfolder,
                    _raise_exceptions_for_missing_entries=True,
                )
            except EnvironmentError:
                # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
                # the original exception.
                raise
            except Exception as err:
                # For any other exception, we throw a generic error.
                raise EnvironmentError(
                    f"Can't load processor for '{pretrained_model_name_or_path}'. If you were trying to load"
                    " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
                    f" directory containing a {IMAGE_PROCESSOR_NAME} file"
                ) from err

        # Fall back to an empty config when no file could be resolved, instead of trying to open `None`.
        image_processor_dict = {}
        if resolved_processor_file is not None:
            try:
                # Load processor dict
                with open(resolved_processor_file, "r", encoding="utf-8") as reader:
                    image_processor_dict = json.loads(reader.read())
            except json.JSONDecodeError as err:
                raise EnvironmentError(
                    f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file."
                ) from err

        for attribute_name in cls.attributes:
            class_name = getattr(cls, f"{attribute_name}_class")
            if isinstance(class_name, tuple):
                if attribute_name == "tokenizer":
                    classes = tuple(getattr(transformers_module, n) if n is not None else None for n in class_name)
                    use_fast = kwargs.get("use_fast", True)
                    if use_fast and classes[1] is not None:
                        attribute_class = classes[1]
                    else:
                        attribute_class = classes[0]
                elif attribute_name == "image_processor":
                    image_processor_type = image_processor_dict.get("image_processor_type", None)
                    if image_processor_type is not None:
                        assert image_processor_type in class_name, f"Invalid image processor type: {image_processor_type}"
                        attribute_class = getattr(transformers_module, image_processor_type)
                    else:
                        attribute_class = getattr(transformers_module, class_name[0])
                else:
                    raise ValueError(f"Invalid attribute name: {attribute_name}")
            else:
                attribute_class = getattr(transformers_module, class_name)

            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
        return args
384
+
models/mllava/utils.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script includes code copied directly from https://huggingface.co/spaces/TIGER-Lab/Mantis
2
+
3
+ import PIL
4
+ import torch
5
+ from .modeling_llava import LlavaForConditionalGeneration
6
+ from .processing_llava import MLlavaProcessor
7
+ # from ..conversation import conv_mllava_v1_mmtag as default_conv
8
+ from ..conversation import conv_mllava_v1 as default_conv, conv_templates
9
+
10
+ from typing import List, Tuple, Union, Tuple
11
+
12
def chat_mllava(
    text:str,
    images: List[Union[PIL.Image.Image, str]],
    model:LlavaForConditionalGeneration,
    processor:MLlavaProcessor,
    max_input_length:int=None,
    history:List[dict]=None,
    **kwargs) -> Tuple[str, List[dict]]:
    """
    Chat with the Mllava model
    Args:
        text: str, the text to be sent to the model, where <image> will be the placeholder for the image
        images: List[PIL.Image.Image | str], the images (or image file paths) to be sent to the model, or None.
            The list itself is not modified; every image is converted to RGB.
        model: LlavaForConditionalGeneration, the model to be used
        processor: MLlavaProcessor, the processor to be used
        max_input_length: int, the maximum input length
        history: List[dict], list of messages in the conversation as history. Each message is a dictionary {"role": "ASSISTANT/USER", "text": "the message"}. If None or empty, the conversation will start from scratch. A list passed in is extended in place.
        kwargs: dict, the generation kwargs. "eos_token_id" is always overwritten by this function.
    Returns:
        Tuple[str, List[dict]], the generated text and the history of the conversation


    """
    if "llama-3" in model.language_model.name_or_path.lower():
        conv = conv_templates['llama_3']
        terminators = [
            processor.tokenizer.eos_token_id,
            processor.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]
    else:
        conv = default_conv
        terminators = None
    kwargs["eos_token_id"] = terminators
    conv = conv.copy()
    conv.messages = []
    user_role, assistant_role = conv.roles[0], conv.roles[1]

    if history is None:
        history = []
    if history:
        for message in history:
            assert message["role"] in conv.roles
            conv.append_message(message["role"], message["text"])
        if text:
            assert conv.messages[-1][0] == assistant_role, "The last message in the history should be the assistant, if the given text is not empty"
            conv.append_message(user_role, text)
            conv.append_message(assistant_role, "")
            history.append({"role": user_role, "text": text})
            history.append({"role": assistant_role, "text": ""})
        else:
            if conv.messages[-1][0] == assistant_role:
                assert conv.messages[-1][1] == "", "No user message should be provided"
            else:
                assert conv.messages[-1][0] == user_role, "The last message in the history should be the user, if the given text is empty"
                # Open an empty assistant turn for the model to fill; the format
                # checks below require the last message to be an empty assistant one.
                conv.append_message(assistant_role, "")
                history.append({"role": assistant_role, "text": ""})
    else:
        # No previous turns (None or an empty list): start a fresh conversation.
        history.append({"role": user_role, "text": text})
        history.append({"role": assistant_role, "text": ""})
        conv.append_message(user_role, text)
        conv.append_message(assistant_role, "")
    assert conv.messages[-1][0] == assistant_role and conv.messages[-1][1] == "", "Format check"
    assert history[-1]["role"] == assistant_role and history[-1]["text"] == "", "Format check"

    prompt = conv.get_prompt()
    if images:
        # Build a new list so the caller's list is left untouched; paths are opened first.
        images = [
            (PIL.Image.open(image) if isinstance(image, str) else image).convert("RGB")
            for image in images
        ]
    else:
        # An empty list means "no images"; the processor expects None in that case.
        images = None

    inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
    for k, v in inputs.items():
        if v is not None:
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(model.device)
            elif isinstance(v, list):
                inputs[k] = [x.to(model.device) for x in v]
            else:
                raise ValueError(f"Invalid input type: {type(v)}")

    output_ids = model.generate(**inputs, **kwargs)
    output_ids = output_ids[0]

    # remove the input tokens
    generated_ids = output_ids[inputs["input_ids"].shape[-1]:]
    generated_text = processor.decode(generated_ids, skip_special_tokens=True)

    # Fill the assistant placeholder appended above with the model's answer.
    history[-1]["text"] = generated_text

    return generated_text, history
100
+
101
+
102
def prepare_inputs(
    text:str,
    images: List[Union[PIL.Image.Image, str]],
    model: LlavaForConditionalGeneration,
    processor: MLlavaProcessor,
    max_input_length: int=None,
    history: List[dict]=None,
    **kwargs):
    """
    Build the generation inputs for the Mllava model without running generation
    Args:
        text: str, the text to be sent to the model, where <image> will be the placeholder for the image
        images: List[PIL.Image.Image | str], the images (or image file paths) to be sent to the model, or None.
            The list itself is not modified; every image is converted to RGB.
        model: LlavaForConditionalGeneration, the model to be used
        processor: MLlavaProcessor, the processor to be used
        max_input_length: int, the maximum input length
        history: List[dict], list of messages in the conversation as history. Each message is a dictionary {"role": "ASSISTANT/USER", "text": "the message"}. If None or empty, the conversation will start from scratch. A list passed in is extended in place.
        kwargs: dict, the generation kwargs. "eos_token_id" is always overwritten by this function.
    Returns:
        The processor output (input_ids, attention_mask, pixel_values) with tensors moved to the model's
        device, floating-point tensors cast to the model's dtype, and the generation kwargs merged in.


    """
    if "llama-3" in model.language_model.name_or_path.lower():
        conv = conv_templates['llama_3']
        terminators = [
            processor.tokenizer.eos_token_id,
            processor.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]
    else:
        conv = default_conv
        terminators = None
    kwargs["eos_token_id"] = terminators
    conv = conv.copy()
    conv.messages = []
    user_role, assistant_role = conv.roles[0], conv.roles[1]

    if history is None:
        history = []
    if history:
        for message in history:
            assert message["role"] in conv.roles
            conv.append_message(message["role"], message["text"])
        if text:
            assert conv.messages[-1][0] == assistant_role, "The last message in the history should be the assistant, if the given text is not empty"
            conv.append_message(user_role, text)
            conv.append_message(assistant_role, "")
            history.append({"role": user_role, "text": text})
            history.append({"role": assistant_role, "text": ""})
        else:
            if conv.messages[-1][0] == assistant_role:
                assert conv.messages[-1][1] == "", "No user message should be provided"
            else:
                assert conv.messages[-1][0] == user_role, "The last message in the history should be the user, if the given text is empty"
                # Open an empty assistant turn for the model to fill; the format
                # checks below require the last message to be an empty assistant one.
                conv.append_message(assistant_role, "")
                history.append({"role": assistant_role, "text": ""})
    else:
        # No previous turns (None or an empty list): start a fresh conversation.
        history.append({"role": user_role, "text": text})
        history.append({"role": assistant_role, "text": ""})
        conv.append_message(user_role, text)
        conv.append_message(assistant_role, "")
    assert conv.messages[-1][0] == assistant_role and conv.messages[-1][1] == "", "Format check"
    assert history[-1]["role"] == assistant_role and history[-1]["text"] == "", "Format check"

    prompt = conv.get_prompt()
    if images:
        # Build a new list so the caller's list is left untouched; paths are opened first.
        images = [
            (PIL.Image.open(image) if isinstance(image, str) else image).convert("RGB")
            for image in images
        ]
    else:
        # An empty list means "no images"; the processor expects None in that case.
        images = None

    inputs = processor(images=images, text=prompt, return_tensors="pt", truncation=True, max_length=max_input_length)
    # Move tensors by hand so that non-tensor entries (pixel_values is None for a
    # text-only prompt) are skipped. Only floating-point tensors are cast to the
    # model dtype; integer tensors such as input_ids keep theirs.
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            if torch.is_floating_point(v):
                inputs[k] = v.to(model.device, model.dtype)
            else:
                inputs[k] = v.to(model.device)
    inputs.update(kwargs)
    return inputs
requirements.txt CHANGED
@@ -1 +1,7 @@
1
- huggingface_hub==0.22.2
 
 
 
 
 
 
 
1
+ torch
2
+ transformers==4.42.3
3
+ Pillow
4
+ gradio==4.36.1
5
+ spaces
6
+ multiprocess
7
+ accelerate