Upload 17 files
- config.json +5 -4
- configuration_minicpm.py +3 -2
- generation_config.json +1 -1
- image_processing_minicpmv.py +55 -88
- model.safetensors +2 -2
- modeling_minicpmv.py +99 -83
- modeling_navit_siglip.py +19 -17
- processing_minicpmv.py +57 -52
- resampler.py +226 -150
- tokenization_minicpmv_fast.py +5 -5
config.json  CHANGED

```diff
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "
+  "_name_or_path": "/home/ea/work/my_optimum_intel/optimum-intel/tiny-random-minicpmv-2_6",
   "architectures": [
     "MiniCPMV"
   ],
@@ -17,7 +17,7 @@
   "hidden_size": 256,
   "image_size": 28,
   "initializer_range": 0.02,
-  "intermediate_size":
+  "intermediate_size": 128,
   "max_position_embeddings": 32768,
   "max_window_layers": 2,
   "model_type": "minicpmv",
@@ -37,16 +37,17 @@
   "sliding_window": null,
   "tie_word_embeddings": false,
   "torch_dtype": "float32",
-  "transformers_version": "4.
+  "transformers_version": "4.46.1",
   "use_cache": true,
   "use_image_id": true,
   "use_sliding_window": false,
   "version": 2.6,
   "vision_batch_size": 16,
   "vision_config": {
+    "_attn_implementation_autoset": true,
     "hidden_size": 64,
     "image_size": 28,
-    "intermediate_size":
+    "intermediate_size": 128,
     "model_type": "siglip_vision_model",
     "num_attention_heads": 2,
     "num_hidden_layers": 4,
```
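For context, this config is what `AutoConfig` resolves when the repo is loaded with remote code enabled. A minimal sketch, assuming a hypothetical local clone path; `trust_remote_code=True` is needed so `configuration_minicpm.py` from this repo can define the `minicpmv` model type:

```python
from transformers import AutoConfig

# Hypothetical path to a local clone of this repo; trust_remote_code lets
# transformers import configuration_minicpm.py to resolve model_type "minicpmv".
config = AutoConfig.from_pretrained("./tiny-random-minicpmv-2_6", trust_remote_code=True)

print(config.model_type)                       # "minicpmv"
print(config.hidden_size)                      # 256 (text side)
print(config.vision_config.intermediate_size)  # 128 (added in this commit)
```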
configuration_minicpm.py  CHANGED

```diff
@@ -4,10 +4,12 @@
 import os
 from typing import Union
 
+from transformers import PretrainedConfig, Qwen2Config
 from transformers.utils import logging
-
+
 from .modeling_navit_siglip import SiglipVisionConfig
 
+
 logger = logging.get_logger(__name__)
 
 
@@ -44,7 +46,6 @@ class MiniCPMVSliceConfig(PretrainedConfig):
         return cls.from_dict(config_dict, **kwargs)
 
 
-
 class MiniCPMVConfig(Qwen2Config):
     model_type = "minicpmv"
     keys_to_ignore_at_inference = ["past_key_values"]
```
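Because `MiniCPMVConfig` subclasses `Qwen2Config`, the text-model fields in `config.json` are interpreted exactly as for a Qwen2 checkpoint, with the vision settings nested under `vision_config`. A rough sketch of that split; plain `Qwen2Config` is used here only to illustrate the inherited fields, it is not part of this commit:

```python
from transformers import Qwen2Config

# The inherited text-side fields from config.json above.
text_cfg = Qwen2Config(
    hidden_size=256,
    intermediate_size=128,
    max_position_embeddings=32768,
    tie_word_embeddings=False,
)
# MiniCPMVConfig layers the vision settings on top of these, e.g. a nested
# SiglipVisionConfig built from the "vision_config" block of config.json.
print(text_cfg.hidden_size, text_cfg.max_position_embeddings)
```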
generation_config.json  CHANGED

```diff
@@ -2,5 +2,5 @@
   "_from_model_config": true,
   "bos_token_id": 151643,
   "eos_token_id": 151645,
-  "transformers_version": "4.
+  "transformers_version": "4.46.1"
 }
```
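The bos/eos ids here are the Qwen2 chat tokens. A quick sketch of how they would be applied at generation time; the `GenerationConfig` construction is illustrative, not part of this commit:

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig(bos_token_id=151643, eos_token_id=151645)
# Passing this to model.generate(..., generation_config=gen_cfg) makes decoding
# stop at token id 151645, which is <|im_end|> in the Qwen2 vocabulary.
print(gen_cfg.eos_token_id)
```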
image_processing_minicpmv.py  CHANGED

```diff
@@ -1,27 +1,23 @@
-from typing import Optional, Union, Dict, Any, List
-
-import torch
 import math
-import
-
+from typing import Any, Dict, List, Optional, Union
+
 import numpy as np
 import PIL
+import PIL.Image
+import PIL.ImageSequence
+import torch
 from PIL import Image
-
-from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device
-from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
 from transformers import AutoImageProcessor
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
 from transformers.image_transforms import to_channel_dimension_format
 from transformers.image_utils import (
-
-    make_list_of_images,
-    valid_images,
-    is_torch_tensor,
-    is_batched,
-    to_numpy_array,
+    ChannelDimension,
     infer_channel_dimension_format,
-
+    is_torch_tensor,
+    to_numpy_array,
+    valid_images,
 )
+from transformers.utils import TensorType, is_torch_device, is_torch_dtype, requires_backends
 
 
 def recursive_converter(converter, value):
@@ -38,6 +34,7 @@ class MiniCPMVBatchFeature(BatchFeature):
     r"""
     Extend from BatchFeature for supporting various image size
     """
+
     def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
         super().__init__(data)
         self.convert_to_tensors(tensor_type=tensor_type)
@@ -45,7 +42,7 @@ class MiniCPMVBatchFeature(BatchFeature):
     def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
         if tensor_type is None:
             return self
-
+
         is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
 
         def converter(value):
@@ -61,11 +58,10 @@ class MiniCPMVBatchFeature(BatchFeature):
                         "with 'padding=True' to have batched tensors with the same length."
                     )
 
-
         for key, value in self.items():
             self[key] = recursive_converter(converter, value)
         return self
-
+
     def to(self, *args, **kwargs) -> "MiniCPMVBatchFeature":
         requires_backends(self, ["torch"])
         import torch
@@ -104,12 +100,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
 class MiniCPMVImageProcessor(BaseImageProcessor):
     model_input_names = ["pixel_values"]
 
-    def __init__(
-            self,
-            max_slice_nums=9,
-            scale_resolution=448,
-            patch_size=14,
-            **kwargs):
+    def __init__(self, max_slice_nums=9, scale_resolution=448, patch_size=14, **kwargs):
         super().__init__(**kwargs)
         self.max_slice_nums = max_slice_nums
         self.scale_resolution = scale_resolution
@@ -131,14 +122,9 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
     def ensure_divide(self, length, patch_size):
         return max(round(length / patch_size) * patch_size, patch_size)
 
-    def find_best_resize(self,
-                         original_size,
-                         scale_resolution,
-                         patch_size,
-                         allow_upscale=False):
+    def find_best_resize(self, original_size, scale_resolution, patch_size, allow_upscale=False):
         width, height = original_size
-        if (width * height >
-                scale_resolution * scale_resolution) or allow_upscale:
+        if (width * height > scale_resolution * scale_resolution) or allow_upscale:
             r = width / height
             height = int(scale_resolution / math.sqrt(r))
             width = int(height * r)
@@ -146,12 +132,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         best_height = self.ensure_divide(height, patch_size)
         return (best_width, best_height)
 
-    def get_refine_size(self,
-                        original_size,
-                        grid,
-                        scale_resolution,
-                        patch_size,
-                        allow_upscale=False):
+    def get_refine_size(self, original_size, grid, scale_resolution, patch_size, allow_upscale=False):
         width, height = original_size
         grid_x, grid_y = grid
 
@@ -161,10 +142,9 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         grid_width = refine_width / grid_x
         grid_height = refine_height / grid_y
 
-        best_grid_size = self.find_best_resize(
-
-
-            allow_upscale=allow_upscale)
+        best_grid_size = self.find_best_resize(
+            (grid_width, grid_height), scale_resolution, patch_size, allow_upscale=allow_upscale
+        )
         refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
         return refine_size
 
@@ -182,9 +162,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         patches.append(images)
         return patches
 
-    def slice_image(
-        self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
-    ):
+    def slice_image(self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False):
         original_size = image.size
         source_image = None
         best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split)
@@ -192,9 +170,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
 
         if best_grid is None:
             # dont need to slice, upsample
-            best_size = self.find_best_resize(
-                original_size, scale_resolution, patch_size, allow_upscale=True
-            )
+            best_size = self.find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=True)
             source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
         else:
             # source image, down-sampling and ensure divided by patch_size
@@ -212,9 +188,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         if grid is None:
             return ""
         slice_image_placeholder = (
-            self.slice_start_token
-            + self.unk_token * self.image_feature_size
-            + self.slice_end_token
+            self.slice_start_token + self.unk_token * self.image_feature_size + self.slice_end_token
         )
 
         cols = grid[0]
@@ -225,13 +199,13 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             for j in range(cols):
                 lines.append(slice_image_placeholder)
             slices.append("".join(lines))
-
+
         slice_placeholder = "\n".join(slices)
         return slice_placeholder
 
     def get_image_id_placeholder(self, idx=0):
         return f"{self.im_id_start}{idx}{self.im_id_end}"
-
+
     def get_sliced_images(self, image, max_slice_nums=None):
         slice_images = []
 
@@ -239,12 +213,9 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             return [image]
 
         max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
-        assert max_slice_nums > 0
+        assert max_slice_nums > 0
         source_image, patches, sliced_grid = self.slice_image(
-            image,
-            max_slice_nums,  # default: 9
-            self.scale_resolution,  # default: 448
-            self.patch_size  # default: 14
+            image, max_slice_nums, self.scale_resolution, self.patch_size  # default: 9 # default: 448 # default: 14
         )
 
         slice_images.append(source_image)
@@ -266,7 +237,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             if i == 1 or i > max_slice_nums:
                 continue
             candidate_split_grids_nums.append(i)
-
+
         candidate_grids = []
         for split_grids_nums in candidate_split_grids_nums:
             m = 1
@@ -282,19 +253,15 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
                 if error < min_error:
                     best_grid = grid
                     min_error = error
-
+
         return best_grid
-
+
     def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None):
         max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums)
-        assert max_slice_nums > 0
+        assert max_slice_nums > 0
         grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums)
 
-        image_placeholder = (
-            self.im_start_token
-            + self.unk_token * self.image_feature_size
-            + self.im_end_token
-        )
+        image_placeholder = self.im_start_token + self.unk_token * self.image_feature_size + self.im_end_token
         use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id)
         if use_image_id:
             final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder
@@ -304,7 +271,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         if self.slice_mode:
             final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid)
         return final_placeholder
-
+
     def to_pil_image(self, image, rescale=None) -> PIL.Image.Image:
         """
         Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
@@ -343,24 +310,20 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         """
        image = torch.from_numpy(image)
         patch_size = self.patch_size
-        patches = torch.nn.functional.unfold(
-            image,
-            (patch_size, patch_size),
-            stride=(patch_size, patch_size)
-        )
+        patches = torch.nn.functional.unfold(image, (patch_size, patch_size), stride=(patch_size, patch_size))
 
         patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
         patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
         return patches.numpy()
 
     def preprocess(
-
-
-
-
-
-
-
+        self,
+        images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
+        do_pad: Optional[bool] = True,  # TODO: add pad for MiniCPM-Llama3-V-2_5
+        max_slice_nums: int = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> MiniCPMVBatchFeature:
         if isinstance(images, Image.Image):
             images_list = [[images]]
         elif isinstance(images[0], Image.Image):
@@ -371,19 +334,19 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         new_images_list = []
         image_sizes_list = []
         tgt_sizes_list = []
-
+
         for _images in images_list:
             if _images is None or len(_images) == 0:
                 new_images_list.append([])
                 image_sizes_list.append([])
                 tgt_sizes_list.append([])
-                continue
+                continue
             if not valid_images(_images):
                 raise ValueError(
                     "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                     "torch.Tensor, tf.Tensor or jax.ndarray."
                 )
-
+
             _images = [self.to_pil_image(image).convert("RGB") for image in _images]
             input_data_format = infer_channel_dimension_format(np.array(_images[0]))
 
@@ -395,24 +358,28 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches]
             image_patches = [
                 self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format)
-
+                for image in image_patches
             ]
             image_patches = [
-                to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
-
+                to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
+                for image in image_patches
             ]
             for slice_image in image_patches:
                 new_images.append(self.reshape_by_patch(slice_image))
-                tgt_sizes.append(
+                tgt_sizes.append(
+                    np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
+                )
 
             if tgt_sizes:
                 tgt_sizes = np.vstack(tgt_sizes)
-
+
             new_images_list.append(new_images)
             image_sizes_list.append(image_sizes)
             tgt_sizes_list.append(tgt_sizes)
         return MiniCPMVBatchFeature(
-            data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
+            data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
+            tensor_type=return_tensors,
         )
 
+
 AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)
```
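The reworked `preprocess` signature above accepts a single image, a flat list, or a batch of lists and returns a `MiniCPMVBatchFeature` with `pixel_values`, `image_sizes`, and `tgt_sizes`. A usage sketch, assuming a hypothetical local clone path and remote-code loading:

```python
from PIL import Image
from transformers import AutoImageProcessor

# Hypothetical local clone of this repo; trust_remote_code imports
# image_processing_minicpmv.py so MiniCPMVImageProcessor can be resolved.
image_processor = AutoImageProcessor.from_pretrained("./tiny-random-minicpmv-2_6", trust_remote_code=True)

image = Image.new("RGB", (448, 448))
batch = image_processor(image, return_tensors="pt")

# One list per sample; each entry holds the source image plus its slices,
# already normalized and flattened into patch rows by reshape_by_patch.
print(len(batch["pixel_values"][0]))
print(batch["tgt_sizes"][0])  # (height_in_patches, width_in_patches) per slice
```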
model.safetensors  CHANGED

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:5a13c2a624f4445809755648b73465369b127bf9f4c7a6a87ccf0c7498039149
+size 315498808
```
modeling_minicpmv.py  CHANGED

```diff
@@ -1,20 +1,17 @@
-import math
-from typing import List, Optional
 import json
-import
-import torchvision
-
-from threading import Thread
+import math
 from copy import deepcopy
+from threading import Thread
+
+import torch
 from PIL import Image
-from transformers import AutoProcessor,
+from transformers import AutoProcessor, Qwen2ForCausalLM, Qwen2PreTrainedModel, TextIteratorStreamer
 
 from .configuration_minicpm import MiniCPMVConfig
 from .modeling_navit_siglip import SiglipVisionTransformer
 from .resampler import Resampler
 
 
-
 class MiniCPMVPreTrainedModel(Qwen2PreTrainedModel):
     config_class = MiniCPMVConfig
 
@@ -29,21 +26,21 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
         self.processor = None
 
-        self.terminators = [
+        self.terminators = ["<|im_end|>", "<|endoftext|>"]
 
     def init_vision_module(self):
         # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
-        if self.config._attn_implementation ==
-            self.config.vision_config._attn_implementation =
+        if self.config._attn_implementation == "flash_attention_2":
+            self.config.vision_config._attn_implementation = "flash_attention_2"
         else:
             # not suport sdpa
-            self.config.vision_config._attn_implementation =
+            self.config.vision_config._attn_implementation = "eager"
         model = SiglipVisionTransformer(self.config.vision_config)
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
 
-        setattr(model,
-        setattr(model,
+        setattr(model, "embed_dim", model.embeddings.embed_dim)
+        setattr(model, "patch_size", model.embeddings.patch_size)
 
         return model
 
@@ -53,7 +50,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             embed_dim=embed_dim,
             num_heads=embed_dim // 128,
             kv_dim=vision_dim,
-            adaptive=True
+            adaptive=True,
         )
 
     def get_input_embeddings(self):
@@ -75,11 +72,11 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         return self.llm
 
     def get_vllm_embedding(self, data):
-        if
+        if "vision_hidden_states" not in data:
             dtype = self.llm.model.embed_tokens.weight.dtype
             device = self.llm.model.embed_tokens.weight.device
-            tgt_sizes = data[
-            pixel_values_list = data[
+            tgt_sizes = data["tgt_sizes"]
+            pixel_values_list = data["pixel_values"]
             vision_hidden_states = []
             all_pixel_values = []
             img_cnt = []
@@ -94,14 +91,15 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
             max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
 
-            all_pixel_values = torch.nn.utils.rnn.pad_sequence(
-
+            all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+                all_pixel_values, batch_first=True, padding_value=0.0
+            )
             B, L, _ = all_pixel_values.shape
             all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
 
             patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
             for i in range(B):
-                patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+                patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
 
             vision_batch_size = self.config.vision_batch_size
             all_pixel_values = all_pixel_values.type(dtype)
@@ -110,28 +108,33 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 for i in range(0, B, vision_batch_size):
                     start_idx = i
                     end_idx = i + vision_batch_size
-                    tmp_hs = self.vpm(
+                    tmp_hs = self.vpm(
+                        all_pixel_values[start_idx:end_idx],
+                        patch_attention_mask=patch_attn_mask[start_idx:end_idx],
+                        tgt_sizes=tgt_sizes[start_idx:end_idx],
+                    ).last_hidden_state
                     hs.append(tmp_hs)
                 vision_embedding = torch.cat(hs, dim=0)
             else:
-                vision_embedding = self.vpm(
+                vision_embedding = self.vpm(
+                    all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
+                ).last_hidden_state
             vision_embedding = self.resampler(vision_embedding, tgt_sizes)
 
             start = 0
             for pixel_values in pixel_values_list:
                 img_cnt = len(pixel_values)
                 if img_cnt > 0:
-                    vision_hidden_states.append(vision_embedding[start: start + img_cnt])
+                    vision_hidden_states.append(vision_embedding[start : start + img_cnt])
                     start += img_cnt
                 else:
                     vision_hidden_states.append([])
-        else:
+        else:  # no image
             if self.training:
-                dummy_image = torch.zeros(
-
-
-                )
-                tgt_sizes = torch.Tensor([[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]).type(torch.int32)
+                dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype)
+                tgt_sizes = torch.Tensor(
+                    [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
+                ).type(torch.int32)
                 dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes)
             else:
                 dummy_feature = []
@@ -139,29 +142,33 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             vision_hidden_states.append(dummy_feature)
 
         else:
-            vision_hidden_states = data[
+            vision_hidden_states = data["vision_hidden_states"]
 
-        if hasattr(self.llm.config,
-            vllm_embedding = self.llm.model.embed_tokens(data[
+        if hasattr(self.llm.config, "scale_emb"):
+            vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
         else:
-            vllm_embedding = self.llm.model.embed_tokens(data[
+            vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
 
-        vision_hidden_states = [
-            i, torch.Tensor) else i for i in vision_hidden_states
+        vision_hidden_states = [
+            i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
+        ]
 
-        bs = len(data[
+        bs = len(data["input_ids"])
         for i in range(bs):
             cur_vs_hs = vision_hidden_states[i]
             if len(cur_vs_hs) > 0:
                 cur_vllm_emb = vllm_embedding[i]
-                cur_image_bound = data[
+                cur_image_bound = data["image_bound"][i]
                 if len(cur_image_bound) > 0:
                     image_indices = torch.stack(
                         [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                     ).to(vllm_embedding.device)
 
-                    cur_vllm_emb.scatter_(
-
+                    cur_vllm_emb.scatter_(
+                        0,
+                        image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
+                        cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
+                    )
                 elif self.training:
                     cur_vllm_emb += cur_vs_hs[0].mean() * 0
 
@@ -173,13 +180,8 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         if position_ids.dtype != torch.int64:
             position_ids = position_ids.long()
 
-        return self.llm(
-
-            position_ids=position_ids,
-            inputs_embeds=vllm_embedding,
-            **kwargs
-        )
-
+        return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)
+
     def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs):
         terminators = None
         if tokenizer is not None:
@@ -187,10 +189,10 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             kwargs.pop("image_sizes")
         output = self.llm.generate(
             inputs_embeds=inputs_embeds,
-            #pad_token_id=0,
+            # pad_token_id=0,
             eos_token_id=terminators,
             attention_mask=attention_mask,
-            **kwargs
+            **kwargs,
         )
         if decode_text:
             return self._decode_text(output, tokenizer)
@@ -200,16 +202,16 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
         streamer = TextIteratorStreamer(tokenizer=tokenizer)
         generation_kwargs = {
-
-
-
-
+            "inputs_embeds": inputs_embeds,
+            "pad_token_id": 0,
+            "eos_token_id": terminators,
+            "streamer": streamer,
         }
         generation_kwargs.update(kwargs)
 
         thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
         thread.start()
-
+
         return streamer
 
     def _decode_text(self, result_ids, tokenizer):
@@ -236,7 +238,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         return_vision_hidden_states=False,
         stream=False,
         decode_text=False,
-        **kwargs
+        **kwargs,
     ):
         assert input_ids is not None
         assert len(input_ids) == len(pixel_values)
@@ -248,7 +250,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
 
         if vision_hidden_states is None:
             model_inputs["pixel_values"] = pixel_values
-            model_inputs[
+            model_inputs["tgt_sizes"] = tgt_sizes
         else:
             model_inputs["vision_hidden_states"] = vision_hidden_states
 
@@ -261,11 +263,13 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         if stream:
             result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
         else:
-            result = self._decode(
+            result = self._decode(
+                model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs
+            )
 
         if return_vision_hidden_states:
             return result, vision_hidden_states
-
+
         return result
 
     def chat(
@@ -279,11 +283,11 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         min_new_tokens=0,
         sampling=True,
         max_inp_length=8192,
-        system_prompt=
+        system_prompt="",
         stream=False,
         max_slice_nums=None,
         use_image_id=None,
-        **kwargs
+        **kwargs,
     ):
         if isinstance(msgs[0], list):
             batched = True
@@ -291,7 +295,7 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             batched = False
             msgs_list = msgs
             images_list = image
-
+
         if batched is False:
             images_list, msgs_list = [images_list], [msgs_list]
         else:
@@ -303,12 +307,22 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
         if self.processor is None:
             self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
         processor = self.processor
-
-        assert
-
-
-        assert
-
+
+        assert (
+            self.config.query_num == processor.image_processor.image_feature_size
+        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+        assert (
+            self.config.patch_size == processor.image_processor.patch_size
+        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+        assert (
+            self.config.use_image_id == processor.image_processor.use_image_id
+        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+        assert (
+            self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums
+        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
+        assert (
+            self.config.slice_mode == processor.image_processor.slice_mode
+        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
 
         prompts_lists = []
         input_images_lists = []
@@ -342,19 +356,21 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
             msg["content"] = "\n".join(cur_msgs)
 
             if system_prompt:
-                sys_msg = {
-                copy_msgs = [sys_msg] + copy_msgs
+                sys_msg = {"role": "system", "content": system_prompt}
+                copy_msgs = [sys_msg] + copy_msgs
 
-            prompts_lists.append(
+            prompts_lists.append(
+                processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
+            )
             input_images_lists.append(images)
 
         inputs = processor(
-            prompts_lists,
-            input_images_lists,
+            prompts_lists,
+            input_images_lists,
             max_slice_nums=max_slice_nums,
             use_image_id=use_image_id,
-            return_tensors="pt",
-            max_length=max_inp_length
+            return_tensors="pt",
+            max_length=max_inp_length,
        ).to(self.device)
 
         if sampling:
@@ -363,20 +379,18 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 "top_k": 100,
                 "temperature": 0.7,
                 "do_sample": True,
-                "repetition_penalty": 1.05
+                "repetition_penalty": 1.05,
             }
         else:
             generation_config = {
                 "num_beams": 3,
                 "repetition_penalty": 1.2,
             }
-
+
         if min_new_tokens > 0:
-            generation_config[
+            generation_config["min_new_tokens"] = min_new_tokens
 
-        generation_config.update(
-            (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
-        )
+        generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
 
         inputs.pop("image_sizes")
         with torch.inference_mode():
@@ -387,15 +401,17 @@ class MiniCPMV(MiniCPMVPreTrainedModel):
                 vision_hidden_states=vision_hidden_states,
                 stream=stream,
                 decode_text=True,
-                **generation_config
+                **generation_config,
             )
-
+
         if stream:
+
             def stream_gen():
                 for text in res:
                     for term in self.terminators:
-                        text = text.replace(term,
+                        text = text.replace(term, "")
                     yield text
+
             return stream_gen()
 
         else:
```
modeling_navit_siglip.py  CHANGED

```diff
@@ -16,11 +16,11 @@
 # Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
 
 
-import os
 import math
+import os
 import warnings
 from dataclasses import dataclass
-from typing import
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -28,12 +28,11 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn.init import _calculate_fan_in_and_fan_out
-
 from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
-from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import (
     ModelOutput,
     add_start_docstrings,
@@ -42,10 +41,11 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-
+
 
 logger = logging.get_logger(__name__)
 
+
 class SiglipVisionConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
@@ -133,7 +133,7 @@ class SiglipVisionConfig(PretrainedConfig):
         )
 
         return cls.from_dict(config_dict, **kwargs)
-
+
 
 _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
 
@@ -148,7 +148,6 @@ try:
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 except:
     pass
-
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
@@ -318,7 +317,12 @@ class SiglipVisionEmbeddings(nn.Module):
         self.num_positions = self.num_patches
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
 
-    def forward(
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        patch_attention_mask: torch.BoolTensor,
+        tgt_sizes: Optional[torch.IntTensor] = None,
+    ) -> torch.Tensor:
         batch_size = pixel_values.size(0)
 
         patch_embeds = self.patch_embedding(pixel_values)
@@ -643,11 +647,7 @@ class SiglipEncoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.hidden_size
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
-        self.self_attn = (
-            SiglipAttention(config)
-            if not self._use_flash_attention_2
-            else SiglipFlashAttention2(config)
-        )
+        self.self_attn = SiglipAttention(config) if not self._use_flash_attention_2 else SiglipFlashAttention2(config)
         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(config)
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -847,9 +847,9 @@ class SiglipEncoder(nn.Module):
             last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
         )
 
+
 @add_start_docstrings(
-    """The vision model from SigLIP without any head or projection on top.""",
-    SIGLIP_START_DOCSTRING
+    """The vision model from SigLIP without any head or projection on top.""", SIGLIP_START_DOCSTRING
 )
 class SiglipVisionTransformer(SiglipPreTrainedModel):
     config_class = SiglipVisionConfig
@@ -904,14 +904,16 @@ class SiglipVisionTransformer(SiglipPreTrainedModel):
             device=pixel_values.device,
         )
 
-        hidden_states = self.embeddings(
+        hidden_states = self.embeddings(
+            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes
+        )
 
         patch_attention_mask = patch_attention_mask.view(batch_size, -1)
         # The call to `_upad_input` in `_flash_attention_forward` is expensive
         # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
         # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
         if not torch.any(~patch_attention_mask):
-            attention_mask=None
+            attention_mask = None
         else:
             attention_mask = (
                 _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
```
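The `tgt_sizes` argument threaded through `SiglipVisionEmbeddings.forward` pairs with the boolean `patch_attention_mask` built in `get_vllm_embedding`. A small standalone sketch of that masking logic; the shapes mirror the code above, the values are made up:

```python
import torch

# Two images whose patch grids are 2x3 and 1x2 (rows x cols of 14x14 patches).
tgt_sizes = torch.tensor([[2, 3], [1, 2]], dtype=torch.int32)
B = tgt_sizes.shape[0]
max_patches = int(torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]))

# Same construction as in get_vllm_embedding: mark the real (non-padded)
# patch positions of each image as attendable.
patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool)
for i in range(B):
    patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

print(patch_attn_mask.squeeze(1))
# tensor([[ True,  True,  True,  True,  True,  True],
#         [ True,  True, False, False, False, False]])
```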
processing_minicpmv.py  CHANGED

```diff
@@ -16,15 +16,14 @@
 Processor class for MiniCPMV.
 """
 
-from typing import List, Optional, Union, Dict, Any
-import torch
 import re
+from typing import List, Optional, Union
 
-
+import torch
 from transformers.image_utils import ImageInput
 from transformers.processing_utils import ProcessorMixin
-from transformers.tokenization_utils_base import
-from transformers.utils import TensorType
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.utils import TensorType
 
 from .image_processing_minicpmv import MiniCPMVBatchFeature
 
@@ -49,7 +48,7 @@ class MiniCPMVProcessor(ProcessorMixin):
     def __init__(self, image_processor=None, tokenizer=None):
         super().__init__(image_processor, tokenizer)
         self.version = image_processor.version
-
+
     def __call__(
         self,
         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
@@ -59,14 +58,23 @@ class MiniCPMVProcessor(ProcessorMixin):
         max_slice_nums: int = None,
         use_image_id: bool = None,
         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
-        **kwargs
+        **kwargs,
     ) -> MiniCPMVBatchFeature:
-
         image_inputs = None
         if images is not None:
-            image_inputs = self.image_processor(
-
-
+            image_inputs = self.image_processor(
+                images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
+            )
+        return self._convert_images_texts_to_inputs(
+            image_inputs,
+            text,
+            max_slice_nums=max_slice_nums,
+            use_image_id=use_image_id,
+            max_length=max_length,
+            **kwargs,
+            return_tensors=return_tensors,
+        )
+
     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
     def batch_decode(self, *args, **kwargs):
         """
@@ -84,7 +92,7 @@ class MiniCPMVProcessor(ProcessorMixin):
             result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip())
         return result_text
         # return self.tokenizer.batch_decode(*args, **kwargs)
-
+
     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
     def decode(self, *args, **kwargs):
         """
@@ -95,13 +103,13 @@ class MiniCPMVProcessor(ProcessorMixin):
         result = result[result != 0]
         if result[0] == self.tokenizer.bos_id:
             result = result[1:]
-        if result[-1] == self.tokenizer.eos_id or (
+        if result[-1] == self.tokenizer.eos_id or (
+            hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id
+        ):
             result = result[:-1]
         return self.tokenizer.decode(result, *args[1:], **kwargs).strip()
 
-    def _convert(
-        self, input_str, max_inp_length: Optional[int] = None
-    ):
+    def _convert(self, input_str, max_inp_length: Optional[int] = None):
         if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False):
             input_ids = self.tokenizer.encode(input_str)
         else:
@@ -128,23 +136,25 @@ class MiniCPMVProcessor(ProcessorMixin):
         return input_ids, image_bounds
 
     def _convert_images_texts_to_inputs(
-
-
-
-
-
-
-
-
-
-
+        self,
+        images,
+        texts: Union[str, List[str]],
+        truncation=None,
+        max_length=None,
+        max_slice_nums=None,
+        use_image_id=None,
+        return_tensors=None,
+        **kwargs,
+    ):
         if images is None or not len(images):
-            model_inputs = self.tokenizer(
+            model_inputs = self.tokenizer(
+                texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs
+            )
             return MiniCPMVBatchFeature(data={**model_inputs})
-
+
         pattern = "(<image>./</image>)"
         images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"]
-
+
         if isinstance(texts, str):
             texts = [texts]
         input_ids_list = []
@@ -155,33 +165,32 @@ class MiniCPMVProcessor(ProcessorMixin):
             text_chunks = text.split(pattern)
             final_text = ""
             for i in range(len(image_tags)):
-                final_text =
-
-
-
-                    max_slice_nums,
-                    use_image_id
                 )
             final_text += text_chunks[-1]
             input_ids, image_bounds = self._convert(final_text, max_length)
             input_ids_list.append(input_ids)
             image_bounds_list.append(image_bounds)
-        padded_input_ids, padding_lengths = self.pad(
-            input_ids_list,
-            padding_side="left"
-        )
         for i, length in enumerate(padding_lengths):
             image_bounds_list[i] = image_bounds_list[i] + length
         attention_mask = padded_input_ids.ne(0)
 
-        return MiniCPMVBatchFeature(
-
-
-
-
-
-
-
 
     @property
     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
@@ -190,7 +199,6 @@ class MiniCPMVProcessor(ProcessorMixin):
         image_processor_input_names = self.image_processor.model_input_names
         return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
 
-
     def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
         items = []
         if isinstance(inputs[0], list):
@@ -219,10 +227,7 @@ class MiniCPMVProcessor(ProcessorMixin):
                 return torch.stack([item for item in items], dim=0), [0] * batch_size
             tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
         else:
-            tensor = (
-                torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype)
-                + padding_value
-            )
+            tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
 
         padding_length = []
         for i, item in enumerate(items):
```
|
165 |
text_chunks = text.split(pattern)
|
166 |
final_text = ""
|
167 |
for i in range(len(image_tags)):
|
168 |
+
final_text = (
|
169 |
+
final_text
|
170 |
+
+ text_chunks[i]
|
171 |
+
+ self.image_processor.get_slice_image_placeholder(
|
172 |
+
image_sizes[index][i], i, max_slice_nums, use_image_id
|
|
|
173 |
)
|
174 |
+
)
|
175 |
final_text += text_chunks[-1]
|
176 |
input_ids, image_bounds = self._convert(final_text, max_length)
|
177 |
input_ids_list.append(input_ids)
|
178 |
image_bounds_list.append(image_bounds)
|
179 |
+
padded_input_ids, padding_lengths = self.pad(input_ids_list, padding_side="left")
|
|
|
|
|
|
|
180 |
for i, length in enumerate(padding_lengths):
|
181 |
image_bounds_list[i] = image_bounds_list[i] + length
|
182 |
attention_mask = padded_input_ids.ne(0)
|
183 |
|
184 |
+
return MiniCPMVBatchFeature(
|
185 |
+
data={
|
186 |
+
"input_ids": padded_input_ids,
|
187 |
+
"attention_mask": attention_mask,
|
188 |
+
"pixel_values": images,
|
189 |
+
"image_sizes": image_sizes,
|
190 |
+
"image_bound": image_bounds_list,
|
191 |
+
"tgt_sizes": tgt_sizes,
|
192 |
+
}
|
193 |
+
)
|
194 |
|
195 |
@property
|
196 |
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
|
|
|
199 |
image_processor_input_names = self.image_processor.model_input_names
|
200 |
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
201 |
|
|
|
202 |
def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"):
|
203 |
items = []
|
204 |
if isinstance(inputs[0], list):
|
|
|
227 |
return torch.stack([item for item in items], dim=0), [0] * batch_size
|
228 |
tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value
|
229 |
else:
|
230 |
+
tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value
|
|
|
|
|
|
|
231 |
|
232 |
padding_length = []
|
233 |
for i, item in enumerate(items):
|
resampler.py
CHANGED
@@ -1,18 +1,17 @@
+import warnings
 from functools import partial
 from typing import Optional, Tuple

+import numpy as np
 import torch
 import torch.nn.functional as F
+from torch import Tensor, nn
 from torch.nn.functional import *
+from torch.nn.init import trunc_normal_
 from torch.nn.modules.activation import *
 from transformers.integrations import is_deepspeed_zero3_enabled

+
 def get_2d_sincos_pos_embed(embed_dim, image_size):
     """
     image_size: image_size or (image_height, image_width)
@@ -52,10 +51,10 @@ def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
     """
     assert embed_dim % 2 == 0
     omega = np.arange(embed_dim // 2, dtype=np.float32)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)

+    out = np.einsum("hw,d->hwd", pos, omega)  # (H, W, D/2), outer product

     emb_sin = np.sin(out)  # (H, W, D/2)
     emb_cos = np.cos(out)  # (H, W, D/2)
@@ -73,14 +72,14 @@ class Resampler(nn.Module):
     """

     def __init__(
+        self,
+        num_queries,
+        embed_dim,
+        num_heads,
+        kv_dim=None,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        adaptive=False,
+        max_size=(70, 70),
     ):
         super().__init__()
         self.num_queries = num_queries
@@ -101,13 +100,13 @@ class Resampler(nn.Module):
         self.ln_kv = norm_layer(embed_dim)

         self.ln_post = norm_layer(embed_dim)
+        self.proj = nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))

         self._set_2d_pos_cache(self.max_size)

+    def _set_2d_pos_cache(self, max_size, device="cpu"):
         if is_deepspeed_zero3_enabled():
+            device = "cuda"
         pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.embed_dim, max_size)).float().to(device)
         self.register_buffer("pos_embed", pos_embed, persistent=False)

@@ -120,7 +119,7 @@ class Resampler(nn.Module):

     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
@@ -145,10 +144,11 @@ class Resampler(nn.Module):
         for i in range(bs):
             tgt_h, tgt_w = tgt_sizes[i]
             pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype))  # patches * D
+            key_padding_mask[i, patch_len[i] :] = True

+        pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(
+            1, 0, 2
+        )  # BLD => L * B * D

         x = self.kv_proj(x)  # B * L * D
         x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D
@@ -159,7 +159,8 @@ class Resampler(nn.Module):
             self._repeat(q, bs),  # Q * B * D
             x + pos_embed,  # L * B * D + L * B * D
             x,
+            key_padding_mask=key_padding_mask,
+        )[0]
         # out: Q * B * D
         x = out.permute(1, 0, 2)  # B * Q * D

@@ -172,26 +173,44 @@ class Resampler(nn.Module):


 class MultiheadAttention(nn.MultiheadAttention):
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        dropout=0.0,
+        bias=True,
+        add_bias_kv=False,
+        add_zero_attn=False,
+        kdim=None,
+        vdim=None,
+        batch_first=False,
+        device=None,
+        dtype=None,
+    ):
+        super().__init__(
+            embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim, batch_first, device, dtype
+        )

         # rewrite out_proj layer,with nn.Linear
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, device=device, dtype=dtype)

     def forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+        need_weights: bool = True,
+        attn_mask: Optional[Tensor] = None,
+        average_attn_weights: bool = True,
+        is_causal: bool = False,
+    ) -> Tuple[Tensor, Optional[Tensor]]:
+        why_not_fast_path = ""
+        if (
+            (attn_mask is not None and torch.is_floating_point(attn_mask))
+            or (key_padding_mask is not None)
+            and torch.is_floating_point(key_padding_mask)
+        ):
             why_not_fast_path = "floating-point masks are not supported for fast path."

         is_batched = query.dim() == 3
@@ -201,7 +220,7 @@ class MultiheadAttention(nn.MultiheadAttention):
             mask_name="key_padding_mask",
             other_type=F._none_or_dtype(attn_mask),
             other_name="attn_mask",
+            target_type=query.dtype,
         )

         attn_mask = _canonical_mask(
@@ -213,7 +232,6 @@ class MultiheadAttention(nn.MultiheadAttention):
             check_other=False,
         )

         if not is_batched:
             why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
         elif query is not key or key is not value:
@@ -222,12 +240,16 @@ class MultiheadAttention(nn.MultiheadAttention):
             # they don't!
             why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
         elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
+            why_not_fast_path = (
+                f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+            )
         elif self.in_proj_weight is None:
             why_not_fast_path = "in_proj_weight was None"
         elif query.dtype != self.in_proj_weight.dtype:
             # this case will fail anyway, but at least they'll get a useful error message.
+            why_not_fast_path = (
+                f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
+            )
         elif self.training:
             why_not_fast_path = "training is enabled"
         elif (self.num_heads % 2) != 0:
@@ -265,11 +287,15 @@ class MultiheadAttention(nn.MultiheadAttention):
         elif _is_make_fx_tracing():
             why_not_fast_path = "we are running make_fx tracing"
         elif not all(_check_arg_device(x) for x in tensor_args):
+            why_not_fast_path = (
+                "some Tensor argument's device is neither one of "
+                f"cpu, cuda or {torch.utils.backend_registration._privateuse1_backend_name}"
+            )
         elif torch.is_grad_enabled() and any(_arg_requires_grad(x) for x in tensor_args):
+            why_not_fast_path = (
+                "grad is enabled and at least one of query or the "
+                "input/output projection weights or biases requires_grad"
+            )
         if not why_not_fast_path:
             merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)

@@ -287,11 +313,14 @@ class MultiheadAttention(nn.MultiheadAttention):
                     merged_mask,
                     need_weights,
                     average_attn_weights,
+                    mask_type,
+                )

         any_nested = query.is_nested or key.is_nested or value.is_nested
+        assert not any_nested, (
+            "MultiheadAttention does not support NestedTensor outside of its fast path. "
+            + f"The fast path was not hit because {why_not_fast_path}"
+        )

         if self.batch_first and is_batched:
             # make sure that the transpose op does not affect the "is" property
@@ -303,38 +332,60 @@ class MultiheadAttention(nn.MultiheadAttention):
                     value = key
             else:
                 query, key, value = (x.transpose(1, 0) for x in (query, key, value))
+
         if not self._qkv_same_embed_dim:
             attn_output, attn_output_weights = self.multi_head_attention_forward(
+                query,
+                key,
+                value,
+                self.embed_dim,
+                self.num_heads,
+                self.in_proj_weight,
+                self.in_proj_bias,
+                self.bias_k,
+                self.bias_v,
+                self.add_zero_attn,
+                self.dropout,
+                self.out_proj.weight,
+                self.out_proj.bias,
                 training=self.training,
+                key_padding_mask=key_padding_mask,
+                need_weights=need_weights,
                 attn_mask=attn_mask,
                 use_separate_proj_weight=True,
+                q_proj_weight=self.q_proj_weight,
+                k_proj_weight=self.k_proj_weight,
                 v_proj_weight=self.v_proj_weight,
                 average_attn_weights=average_attn_weights,
+                is_causal=is_causal,
+            )
         else:
             attn_output, attn_output_weights = self.multi_head_attention_forward(
+                query,
+                key,
+                value,
+                self.embed_dim,
+                self.num_heads,
+                self.in_proj_weight,
+                self.in_proj_bias,
+                self.bias_k,
+                self.bias_v,
+                self.add_zero_attn,
+                self.dropout,
+                self.out_proj.weight,
+                self.out_proj.bias,
                 training=self.training,
                 key_padding_mask=key_padding_mask,
                 need_weights=need_weights,
                 attn_mask=attn_mask,
                 average_attn_weights=average_attn_weights,
+                is_causal=is_causal,
+            )
         if self.batch_first and is_batched:
             return attn_output.transpose(1, 0), attn_output_weights
         else:
             return attn_output, attn_output_weights
+
     def multi_head_attention_forward(
         self,
         query: Tensor,
@@ -364,9 +415,9 @@ class MultiheadAttention(nn.MultiheadAttention):
         is_causal: bool = False,
     ) -> Tuple[Tensor, Optional[Tensor]]:
         tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
+
         is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
+
         # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
         # is batched, run the computation and before returning squeeze the
         # batch dimension so that the output doesn't carry this temporary batch dimension.
@@ -377,26 +428,26 @@ class MultiheadAttention(nn.MultiheadAttention):
             value = value.unsqueeze(1)
             if key_padding_mask is not None:
                 key_padding_mask = key_padding_mask.unsqueeze(0)
+
         # set up shape vars
         tgt_len, bsz, embed_dim = query.shape
         src_len, _, _ = key.shape
+
         key_padding_mask = _canonical_mask(
             mask=key_padding_mask,
             mask_name="key_padding_mask",
             other_type=_none_or_dtype(attn_mask),
             other_name="attn_mask",
+            target_type=query.dtype,
         )
+
         if is_causal and attn_mask is None:
             raise RuntimeError(
                 "Need attn_mask if specifying the is_causal hint. "
                 "You may use the Transformer module method "
                 "`generate_square_subsequent_mask` to create this mask."
             )
+
         if is_causal and key_padding_mask is None and not need_weights:
             # when we have a kpm or need weights, we need attn_mask
             # Otherwise, we use the is_causal hint go as is_causal
@@ -411,28 +462,30 @@ class MultiheadAttention(nn.MultiheadAttention):
             target_type=query.dtype,
             check_other=False,
         )
+
         if key_padding_mask is not None:
             # We have the attn_mask, and use that to merge kpm into it.
             # Turn off use of is_causal hint, as the merged mask is no
             # longer causal.
             is_causal = False
+
+        assert (
+            embed_dim == embed_dim_to_check
+        ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
         if isinstance(embed_dim, torch.Tensor):
             # embed_dim can be a tensor when JIT tracing
+            head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
         else:
             head_dim = embed_dim // num_heads
         assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
         if use_separate_proj_weight:
             # allow MHA to have different embedding dimensions when separate projection weights are used
+            assert (
+                key.shape[:2] == value.shape[:2]
+            ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
         else:
             assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
+
         #
         # compute in-projection
         #
@@ -448,23 +501,27 @@ class MultiheadAttention(nn.MultiheadAttention):
         else:
             b_q, b_k, b_v = in_proj_bias.chunk(3)
             q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
+
         # prep attention mask
+
         if attn_mask is not None:
             # ensure attn_mask's dim is 3
             if attn_mask.dim() == 2:
                 correct_2d_size = (tgt_len, src_len)
                 if attn_mask.shape != correct_2d_size:
+                    raise RuntimeError(
+                        f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
+                    )
                 attn_mask = attn_mask.unsqueeze(0)
             elif attn_mask.dim() == 3:
                 correct_3d_size = (bsz * num_heads, tgt_len, src_len)
                 if attn_mask.shape != correct_3d_size:
+                    raise RuntimeError(
+                        f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
+                    )
             else:
                 raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
+
         # add bias along batch dimension (currently second)
         if bias_k is not None and bias_v is not None:
             assert static_k is None, "bias cannot be added to static key."
@@ -478,7 +535,7 @@ class MultiheadAttention(nn.MultiheadAttention):
         else:
             assert bias_k is None
             assert bias_v is None
+
         #
         # reshape q, k, v for multihead attention and make em batch first
         #
@@ -487,21 +544,25 @@ class MultiheadAttention(nn.MultiheadAttention):
             k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
         else:
             # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+            assert (
+                static_k.size(0) == bsz * num_heads
+            ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
+            assert (
+                static_k.size(2) == head_dim
+            ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
             k = static_k
         if static_v is None:
             v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
         else:
             # TODO finish disentangling control flow so we don't do in-projections when statics are passed
+            assert (
+                static_v.size(0) == bsz * num_heads
+            ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
+            assert (
+                static_v.size(2) == head_dim
+            ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
             v = static_v
+
         # add zero attention along batch dimension (now first)
         if add_zero_attn:
             zero_attn_shape = (bsz * num_heads, 1, head_dim)
@@ -511,35 +572,40 @@ class MultiheadAttention(nn.MultiheadAttention):
                 attn_mask = pad(attn_mask, (0, 1))
             if key_padding_mask is not None:
                 key_padding_mask = pad(key_padding_mask, (0, 1))
+
         # update source sequence length after adjustments
         src_len = k.size(1)
+
         # merge key padding and attention masks
         if key_padding_mask is not None:
+            assert key_padding_mask.shape == (
+                bsz,
+                src_len,
+            ), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
+            key_padding_mask = (
+                key_padding_mask.view(bsz, 1, 1, src_len)
+                .expand(-1, num_heads, -1, -1)
+                .reshape(bsz * num_heads, 1, src_len)
+            )
             if attn_mask is None:
                 attn_mask = key_padding_mask
             else:
                 attn_mask = attn_mask + key_padding_mask
+
         # adjust dropout probability
         if not training:
             dropout_p = 0.0
+
         #
         # (deep breath) calculate attention and out projection
         #
+
         if need_weights:
             B, Nt, E = q.shape
             q_scaled = q / math.sqrt(E)
+
             assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
+
             if attn_mask is not None:
                 attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
             else:
@@ -547,18 +613,18 @@ class MultiheadAttention(nn.MultiheadAttention):
             attn_output_weights = softmax(attn_output_weights, dim=-1)
             if dropout_p > 0.0:
                 attn_output_weights = dropout(attn_output_weights, p=dropout_p)
+
             attn_output = torch.bmm(attn_output_weights, v)
+
             attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
             attn_output = self.out_proj(attn_output)
             attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
+
             # optionally average attention weights over heads
             attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
             if average_attn_weights:
                 attn_output_weights = attn_output_weights.mean(dim=1)
+
             if not is_batched:
                 # squeeze the output if input was unbatched
                 attn_output = attn_output.squeeze(1)
@@ -573,14 +639,14 @@ class MultiheadAttention(nn.MultiheadAttention):
                     attn_mask = attn_mask.unsqueeze(0)
                 else:
                     attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
+
             q = q.view(bsz, num_heads, tgt_len, head_dim)
             k = k.view(bsz, num_heads, src_len, head_dim)
             v = v.view(bsz, num_heads, src_len, head_dim)
+
             attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
             attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
+
             attn_output = self.out_proj(attn_output)
             attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
             if not is_batched:
@@ -589,8 +655,14 @@ class MultiheadAttention(nn.MultiheadAttention):
             return attn_output, None


+def _mha_shape_check(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    key_padding_mask: Optional[Tensor],
+    attn_mask: Optional[Tensor],
+    num_heads: int,
+):
     # Verifies the expected shape for `query, `key`, `value`, `key_padding_mask` and `attn_mask`
     # and returns if the input is batched or not.
     # Raises an error if `query` is not 2-D (unbatched) or 3-D (batched) tensor.
@@ -599,59 +671,65 @@ def _mha_shape_check(query: Tensor, key: Tensor, value: Tensor,
     if query.dim() == 3:
         # Batched Inputs
         is_batched = True
+        assert key.dim() == 3 and value.dim() == 3, (
+            "For batched (3-D) `query`, expected `key` and `value` to be 3-D"
+            f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
+        )
         if key_padding_mask is not None:
+            assert key_padding_mask.dim() == 2, (
+                "For batched (3-D) `query`, expected `key_padding_mask` to be `None` or 2-D"
+                f" but found {key_padding_mask.dim()}-D tensor instead"
+            )
         if attn_mask is not None:
+            assert attn_mask.dim() in (2, 3), (
+                "For batched (3-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+                f" but found {attn_mask.dim()}-D tensor instead"
+            )
     elif query.dim() == 2:
         # Unbatched Inputs
         is_batched = False
+        assert key.dim() == 2 and value.dim() == 2, (
+            "For unbatched (2-D) `query`, expected `key` and `value` to be 2-D"
+            f" but found {key.dim()}-D and {value.dim()}-D tensors respectively"
+        )

         if key_padding_mask is not None:
+            assert key_padding_mask.dim() == 1, (
+                "For unbatched (2-D) `query`, expected `key_padding_mask` to be `None` or 1-D"
+                f" but found {key_padding_mask.dim()}-D tensor instead"
+            )

         if attn_mask is not None:
+            assert attn_mask.dim() in (2, 3), (
+                "For unbatched (2-D) `query`, expected `attn_mask` to be `None`, 2-D or 3-D"
+                f" but found {attn_mask.dim()}-D tensor instead"
+            )
             if attn_mask.dim() == 3:
                 expected_shape = (num_heads, query.shape[0], key.shape[0])
+                assert (
+                    attn_mask.shape == expected_shape
+                ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
     else:
         raise AssertionError(
+            f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor"
+        )

     return is_batched


 def _canonical_mask(
+    mask: Optional[Tensor],
+    mask_name: str,
+    other_type: Optional[DType],
+    other_name: str,
+    target_type: DType,
+    check_other: bool = True,
 ) -> Optional[Tensor]:
     if mask is not None:
         _mask_dtype = mask.dtype
         _mask_is_float = torch.is_floating_point(mask)
         if _mask_dtype != torch.bool and not _mask_is_float:
+            raise AssertionError(f"only bool and floating types of {mask_name} are supported")
         if check_other and other_type is not None:
             if _mask_dtype != other_type:
                 warnings.warn(
@@ -659,10 +737,7 @@ def _canonical_mask(
                     "is deprecated. Use same type for both instead."
                 )
         if not _mask_is_float:
+            mask = torch.zeros_like(mask, dtype=target_type).masked_fill_(mask, float("-inf"))
     return mask


@@ -673,6 +748,7 @@ def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
         return input.dtype
     raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")

+
 def _in_projection_packed(
     q: Tensor,
     k: Tensor,
@@ -779,4 +855,4 @@ def _in_projection(
     assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
     assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
     assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
+    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
tokenization_minicpmv_fast.py
CHANGED
@@ -40,7 +40,7 @@ class MiniCPMVTokenizerFast(Qwen2TokenizerFast):
     @property
     def slice_start_id(self):
         return self.convert_tokens_to_ids(self.slice_start)
+
     @property
     def slice_end_id(self):
         return self.convert_tokens_to_ids(self.slice_end)
@@ -48,14 +48,14 @@ class MiniCPMVTokenizerFast(Qwen2TokenizerFast):
     @property
     def im_id_start_id(self):
         return self.convert_tokens_to_ids(self.im_id_start)
+
     @property
     def im_id_end_id(self):
         return self.convert_tokens_to_ids(self.im_id_end)
+
     @property
     def newline_id(self):
+        return self.convert_tokens_to_ids("\n")

     @staticmethod
     def escape(text: str) -> str:
@@ -63,4 +63,4 @@ class MiniCPMVTokenizerFast(Qwen2TokenizerFast):

     @staticmethod
     def unescape(text: str) -> str:
+        return text