jadechoghari commited on
Commit
cc3d147
1 Parent(s): 2cc1492

add initial files

Browse files
README.md ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ datasets:
3
+ - shenxq/OneVision
4
+ - shenxq/VideoChat2
5
+ base_model:
6
+ - Vision-CAIR/LongVU_Qwen2_7B_img
7
+ pipeline_tag: video-text-to-text
8
+ model-index:
9
+ - name: llava-onevision-qwen-7b-ov
10
+ results:
11
+ - task:
12
+ type: multimodal
13
+ dataset:
14
+ name: EgoSchema
15
+ type: egoschema
16
+ metrics:
17
+ - type: accuracy
18
+ value: 67.6
19
+ name: accuracy
20
+ verified: true
21
+ - task:
22
+ type: multimodal
23
+ dataset:
24
+ name: MLVU
25
+ type: mlvu
26
+ metrics:
27
+ - type: accuracy
28
+ value: 65.4
29
+ name: accuracy
30
+ verified: true
31
+ - task:
32
+ type: multimodal
33
+ dataset:
34
+ name: MVBench
35
+ type: mvbench
36
+ metrics:
37
+ - type: accuracy
38
+ value: 66.9
39
+ name: accuracy
40
+ verified: true
41
+ - task:
42
+ type: multimodal
43
+ dataset:
44
+ name: VideoMME
45
+ type: videomme
46
+ metrics:
47
+ - type: accuracy
48
+ value: 60.6
49
+ name: accuracy
50
+ verified: true
51
+ ---
52
+ # LongVU
53
+
54
+ This repository contains the model based on Qwen2-7B as presented in [LongVU: Spatiotemporal Adaptive Compression for Long Video-Language Understanding](https://huggingface.co/papers/2410.17434).
55
+
56
+ Play with the model on the [HF demo](https://huggingface.co/spaces/Vision-CAIR/LongVU).
57
+
58
+ <div align="left">
59
+ <a href='https://vision-cair.github.io/LongVU'><img src="https://longvu.s3.amazonaws.com/assets/demo.gif" alt="Demo GIF" style="width: 100%; max-width: 650px;"></a>
60
+ </div>
61
+
62
+ # Use
63
+
64
+ We provide the simple generation process for using our model. For more details, you could refer to [Github](https://github.com/Vision-CAIR/LongVU)
65
+
66
+ ```python
67
+ # git clone https://github.com/Vision-CAIR/LongVU
68
+ import numpy as np
69
+ import torch
70
+ from longvu.builder import load_pretrained_model
71
+ from longvu.constants import (
72
+ DEFAULT_IMAGE_TOKEN,
73
+ IMAGE_TOKEN_INDEX,
74
+ )
75
+ from longvu.conversation import conv_templates, SeparatorStyle
76
+ from longvu.mm_datautils import (
77
+ KeywordsStoppingCriteria,
78
+ process_images,
79
+ tokenizer_image_token,
80
+ )
81
+ from decord import cpu, VideoReader
82
+
83
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
84
+ "./checkpoints/longvu_qwen", None, "cambrian_qwen",
85
+ )
86
+
87
+ model.eval()
88
+ video_path = "./examples/video1.mp4"
89
+ qs = "Describe this video in detail"
90
+
91
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
92
+ fps = float(vr.get_avg_fps())
93
+ frame_indices = np.array([i for i in range(0, len(vr), round(fps),)])
94
+ video = []
95
+ for frame_index in frame_indices:
96
+ img = vr[frame_index].asnumpy()
97
+ video.append(img)
98
+ video = np.stack(video)
99
+ image_sizes = [video[0].shape[:2]]
100
+ video = process_images(video, image_processor, model.config)
101
+ video = [item.unsqueeze(0) for item in video]
102
+
103
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
104
+ conv = conv_templates["qwen"].copy()
105
+ conv.append_message(conv.roles[0], qs)
106
+ conv.append_message(conv.roles[1], None)
107
+ prompt = conv.get_prompt()
108
+
109
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)
110
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
111
+ keywords = [stop_str]
112
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
113
+ with torch.inference_mode():
114
+ output_ids = model.generate(
115
+ input_ids,
116
+ images=video,
117
+ image_sizes=image_sizes,
118
+ do_sample=False,
119
+ temperature=0.2,
120
+ max_new_tokens=128,
121
+ use_cache=True,
122
+ stopping_criteria=[stopping_criteria],
123
+ )
124
+ pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
125
+ ```
126
+
127
+ # Citation
128
+
129
+ ```
130
+ @article{shen2024longvu,
131
+ title={LongVU: Spatiotemporal Adaptive Compression for Long Video-Language Understanding},
132
+ author={Shen, Xiaoqian and Xiong, Yunyang and Zhao, Changsheng and Wu, Lemeng and Chen, Jun and Zhu, Chenchen and Liu, Zechun and Xiao, Fanyi and Varadarajan, Balakrishnan and Bordes, Florian and Liu, Zhuang and Xu, Hu and J. Kim, Hyunwoo and Soran, Bilge and Krishnamoorthi, Raghuraman and Elhoseiny, Mohamed and Chandra, Vikas},
133
+ journal={arXiv:2410.17434},
134
+ year={2024}
135
+ }
136
+ ```
cambrian_arch.py ADDED
@@ -0,0 +1,1712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import math
17
+ import random
18
+ from abc import ABC, abstractmethod
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+
24
+ # define the constants
25
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
26
+ WORKER_HEART_BEAT_INTERVAL = 15
27
+
28
+ LOGDIR = "."
29
+
30
+ # Model Constants
31
+ IGNORE_INDEX = -100
32
+ IMAGE_TOKEN_INDEX = -200
33
+ DEFAULT_IMAGE_TOKEN = "<image>"
34
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
35
+ DEFAULT_IM_START_TOKEN = "<im_start>"
36
+ DEFAULT_IM_END_TOKEN = "<im_end>"
37
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
38
+
39
+ from .multimodal_encoder_builder import build_vision_tower_aux_list
40
+ from .multimodal_projector_builder import build_vision_projector
41
+ from .vision_sampler import VisionTokenSampler
42
+
43
+ IS_XLA_AVAILABLE = False
44
+
45
+
46
+ class CambrianMetaModel:
47
+
48
+ def __init__(self, config):
49
+ super(CambrianMetaModel, self).__init__(config)
50
+
51
+ if hasattr(config, "mm_vision_tower_aux_list"):
52
+
53
+ projector_type = getattr(config, "mm_projector_type", "linear")
54
+ if projector_type == "sva":
55
+
56
+ vision_hidden_size = config.vision_hidden_size
57
+ num_query_group = config.num_query_group
58
+ query_num_list = config.query_num_list
59
+ connector_only = config.connector_only
60
+ connector_depth = config.connector_depth
61
+ self.vision_tower_aux_list = build_vision_tower_aux_list(
62
+ config, delay_load=True
63
+ )
64
+ self.mm_projector = nn.Sequential(
65
+ nn.Linear(vision_hidden_size * num_query_group, config.hidden_size),
66
+ nn.GELU(),
67
+ nn.Linear(config.hidden_size, config.hidden_size),
68
+ )
69
+
70
+ image_token_len = config.image_token_len
71
+ vision_tower_aux_token_len_list = (
72
+ self.config.mm_vision_tower_aux_token_len_list
73
+ )
74
+ cross_att_token_len_list = [
75
+ int(vision_tower_aux_token_len**0.5) // int(image_token_len**0.5)
76
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
77
+ ]
78
+
79
+ for aux_i, vision_tower_aux in enumerate(self.vision_tower_aux_list):
80
+ setattr(
81
+ self,
82
+ "mm_projector_aux_{}".format(aux_i),
83
+ nn.Sequential(
84
+ nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size),
85
+ nn.GELU(),
86
+ nn.Linear(vision_hidden_size, vision_hidden_size),
87
+ nn.LayerNorm(vision_hidden_size),
88
+ ),
89
+ )
90
+
91
+ for query_group_i in range(num_query_group):
92
+ cross_att_token_len_list = [
93
+ int(vision_tower_aux_token_len**0.5)
94
+ // int(query_num_list[query_group_i] ** 0.5)
95
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
96
+ ]
97
+ setattr(
98
+ self,
99
+ "vision_sampler_{}".format(query_group_i),
100
+ VisionTokenSampler(
101
+ vision_hidden_size,
102
+ vision_hidden_size,
103
+ [vision_hidden_size] * len(self.vision_tower_aux_list),
104
+ cross_att_token_len_list,
105
+ vision_hidden_size,
106
+ connector_depth,
107
+ ),
108
+ )
109
+
110
+ if not connector_only:
111
+ num_of_vision_sampler_layers = (
112
+ config.num_of_vision_sampler_layers
113
+ ) = config.num_of_vision_sampler_layers
114
+ config.start_of_vision_sampler_layers = (
115
+ config.start_of_vision_sampler_layers
116
+ )
117
+ config.stride_of_vision_sampler_layers = (
118
+ config.stride_of_vision_sampler_layers
119
+ )
120
+ cross_att_token_len_list = [
121
+ int(vision_tower_aux_token_len**0.5)
122
+ // int(image_token_len**0.5)
123
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
124
+ ]
125
+ self.vision_sampler_layers = nn.ModuleList(
126
+ [
127
+ VisionTokenSampler(
128
+ config.hidden_size,
129
+ vision_hidden_size,
130
+ [vision_hidden_size] * len(self.vision_tower_aux_list),
131
+ cross_att_token_len_list,
132
+ vision_hidden_size,
133
+ 1,
134
+ )
135
+ for layer_idx in range(0, num_of_vision_sampler_layers)
136
+ ]
137
+ )
138
+
139
+ self.vision_query = nn.Parameter(
140
+ torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype)
141
+ )
142
+
143
+ self.image_newline = nn.Parameter(
144
+ torch.empty(config.hidden_size, dtype=self.dtype)
145
+ )
146
+
147
+ self.frame_pos = torch.stack(
148
+ [
149
+ 1
150
+ / torch.pow(
151
+ torch.tensor(10000),
152
+ torch.tensor(2 * (hid_j // 2) / config.hidden_size),
153
+ )
154
+ for hid_j in range(config.hidden_size)
155
+ ]
156
+ )
157
+
158
+ else:
159
+ self.vision_tower_aux_list = build_vision_tower_aux_list(
160
+ config, delay_load=True
161
+ )
162
+ config.mm_hidden_size = sum(
163
+ [
164
+ vision_tower_aux.hidden_size
165
+ for vision_tower_aux in self.vision_tower_aux_list
166
+ ]
167
+ )
168
+ self.mm_projector = build_vision_projector(config)
169
+ self.image_newline = nn.Parameter(
170
+ torch.empty(config.hidden_size, dtype=self.dtype)
171
+ )
172
+
173
+ def get_frame_pos(self, time_range):
174
+ frame_pos = self.frame_pos.reshape(1, -1) * time_range.reshape(-1, 1).to(
175
+ self.frame_pos.device
176
+ )
177
+ frame_pos[:, 0::2] = torch.sin(frame_pos[:, 0::2])
178
+ frame_pos[:, 1::2] = torch.cos(frame_pos[:, 0::2])
179
+ frame_pos = frame_pos.unsqueeze(1)
180
+ return frame_pos
181
+
182
+ # def get_vision_tower(self):
183
+ # vision_tower = getattr(self, 'vision_tower', None)
184
+ # if type(vision_tower) is list:
185
+ # vision_tower = vision_tower[0]
186
+ # return vision_tower
187
+
188
+ def get_vision_tower_aux_list(self):
189
+ vision_tower_aux_list = getattr(self, "vision_tower_aux_list", None)
190
+ return vision_tower_aux_list
191
+
192
+ def initialize_vision_modules(self, model_args, fsdp=None):
193
+ # vision_tower = model_args.vision_tower
194
+ num_query_group = model_args.num_query_group
195
+ query_num_list = model_args.query_num_list
196
+ vision_hidden_size = model_args.vision_hidden_size
197
+ vision_tower_aux_list = model_args.vision_tower_aux_list
198
+ vision_tower_aux_token_len_list = model_args.vision_tower_aux_token_len_list
199
+ image_token_len = model_args.image_token_len
200
+ mm_vision_select_layer = model_args.mm_vision_select_layer
201
+ mm_vision_select_feature = model_args.mm_vision_select_feature
202
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
203
+ connector_only = model_args.connector_only
204
+ connector_depth = model_args.connector_depth
205
+
206
+ # self.config.mm_vision_tower = vision_tower
207
+ self.config.image_token_len = image_token_len
208
+ self.config.num_query_group = num_query_group
209
+ self.config.query_num_list = query_num_list
210
+ assert num_query_group == len(query_num_list)
211
+ self.config.connector_depth = connector_depth
212
+ self.config.mm_vision_tower_aux_list = vision_tower_aux_list
213
+ self.config.mm_vision_tower_aux_token_len_list = vision_tower_aux_token_len_list
214
+ self.config.connector_only = connector_only
215
+ self.config.highres_connect = model_args.highres_connect
216
+ self.config.highres = model_args.highres
217
+ self.config.frame_pos = model_args.frame_pos
218
+ self.config.lowres_token = model_args.lowres_token
219
+ self.config.connect_layer = model_args.connect_layer
220
+ self.config.dino_threshold = getattr(model_args, "dino_threshold", 0.83)
221
+ self.config.drop_threshold = getattr(model_args, "drop_threshold", 0.6)
222
+ self.config.is_image_newline = getattr(model_args, "is_image_newline", True)
223
+
224
+ if self.get_vision_tower_aux_list() is None:
225
+ vision_tower_aux_list = build_vision_tower_aux_list(model_args)
226
+ if model_args.unfreeze_mm_vision_tower:
227
+ self.vision_tower_aux_list = nn.ModuleList(vision_tower_aux_list)
228
+ else:
229
+ self.vision_tower_aux_list = vision_tower_aux_list
230
+ else:
231
+ vision_tower_aux_list = self.vision_tower_aux_list
232
+ for vision_tower_aux in vision_tower_aux_list:
233
+ vision_tower_aux.load_model()
234
+
235
+ self.config.use_mm_proj = True
236
+ self.config.mm_projector_type = getattr(
237
+ model_args, "mm_projector_type", "linear"
238
+ )
239
+ self.config.vision_hidden_size = vision_hidden_size
240
+ self.config.mm_vision_select_layer = mm_vision_select_layer
241
+ self.config.mm_vision_select_feature = mm_vision_select_feature
242
+
243
+ if getattr(self, "mm_projector", None) is None:
244
+
245
+ if self.config.mm_projector_type == "sva":
246
+ self.mm_projector = nn.Sequential(
247
+ nn.Linear(
248
+ vision_hidden_size * num_query_group, self.config.hidden_size
249
+ ),
250
+ nn.GELU(),
251
+ nn.Linear(self.config.hidden_size, self.config.hidden_size),
252
+ )
253
+ for aux_i, vision_tower_aux in enumerate(vision_tower_aux_list):
254
+ setattr(
255
+ self,
256
+ "mm_projector_aux_{}".format(aux_i),
257
+ nn.Sequential(
258
+ nn.Linear(vision_tower_aux.hidden_size, vision_hidden_size),
259
+ nn.GELU(),
260
+ nn.Linear(vision_hidden_size, vision_hidden_size),
261
+ nn.LayerNorm(vision_hidden_size),
262
+ ),
263
+ )
264
+
265
+ # vision sampler for each group of query as the connector before the LLM
266
+ for query_group_i in range(num_query_group):
267
+ cross_att_token_len_list = [
268
+ int(vision_tower_aux_token_len**0.5)
269
+ // int(query_num_list[query_group_i] ** 0.5)
270
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
271
+ ]
272
+ setattr(
273
+ self,
274
+ "vision_sampler_{}".format(query_group_i),
275
+ VisionTokenSampler(
276
+ vision_hidden_size,
277
+ vision_hidden_size,
278
+ [vision_hidden_size] * len(vision_tower_aux_list),
279
+ cross_att_token_len_list,
280
+ vision_hidden_size,
281
+ connector_depth,
282
+ ),
283
+ )
284
+
285
+ # sampler layers within LLM
286
+ if not connector_only:
287
+ num_of_vision_sampler_layers = (
288
+ self.config.num_of_vision_sampler_layers
289
+ ) = model_args.num_of_vision_sampler_layers
290
+ self.config.start_of_vision_sampler_layers = (
291
+ model_args.start_of_vision_sampler_layers
292
+ )
293
+ self.config.stride_of_vision_sampler_layers = (
294
+ model_args.stride_of_vision_sampler_layers
295
+ )
296
+ cross_att_token_len_list = [
297
+ int(vision_tower_aux_token_len**0.5)
298
+ // int(image_token_len**0.5)
299
+ for vision_tower_aux_token_len in vision_tower_aux_token_len_list
300
+ ]
301
+ self.vision_sampler_layers = nn.ModuleList(
302
+ [
303
+ VisionTokenSampler(
304
+ self.config.hidden_size,
305
+ vision_hidden_size,
306
+ [vision_hidden_size] * len(vision_tower_aux_list),
307
+ cross_att_token_len_list,
308
+ vision_hidden_size,
309
+ 1,
310
+ )
311
+ for layer_idx in range(0, num_of_vision_sampler_layers)
312
+ ]
313
+ )
314
+ vision_embed_std = 1 / torch.sqrt(
315
+ torch.tensor(vision_hidden_size, dtype=self.dtype)
316
+ )
317
+ self.vision_query = nn.Parameter(
318
+ torch.randn((num_query_group, vision_hidden_size), dtype=self.dtype)
319
+ * vision_embed_std
320
+ )
321
+
322
+ embed_std = 1 / torch.sqrt(
323
+ torch.tensor(self.config.hidden_size, dtype=self.dtype)
324
+ )
325
+ self.image_newline = nn.Parameter(
326
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
327
+ )
328
+
329
+ else:
330
+ self.config.mm_hidden_size = sum(
331
+ [
332
+ vision_tower_aux.hidden_size
333
+ for vision_tower_aux in vision_tower_aux_list
334
+ ]
335
+ )
336
+ self.mm_projector = build_vision_projector(self.config)
337
+ embed_std = 1 / torch.sqrt(
338
+ torch.tensor(self.config.hidden_size, dtype=self.dtype)
339
+ )
340
+ self.image_newline = nn.Parameter(
341
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
342
+ )
343
+ else:
344
+ # In case it is frozen by LoRA
345
+ for p in self.mm_projector.parameters():
346
+ p.requires_grad = True
347
+
348
+ if pretrain_mm_mlp_adapter is not None:
349
+ mm_projector_weights = torch.load(
350
+ pretrain_mm_mlp_adapter, map_location="cpu"
351
+ )
352
+
353
+ def get_w(weights, keyword):
354
+ return {
355
+ k.split(keyword + ".")[1]: v
356
+ for k, v in weights.items()
357
+ if keyword + "." in k
358
+ }
359
+
360
+ self.mm_projector.load_state_dict(
361
+ get_w(mm_projector_weights, "mm_projector"), strict=True
362
+ )
363
+
364
+ if self.config.mm_projector_type == "sva":
365
+ for aux_i in range(len(vision_tower_aux_list)):
366
+ getattr(self, "mm_projector_aux_{}".format(aux_i)).load_state_dict(
367
+ get_w(
368
+ mm_projector_weights, "mm_projector_aux_{}".format(aux_i)
369
+ ),
370
+ strict=True,
371
+ )
372
+
373
+ for query_group_i in range(num_query_group):
374
+ getattr(
375
+ self, "vision_sampler_{}".format(query_group_i)
376
+ ).load_state_dict(
377
+ get_w(
378
+ mm_projector_weights,
379
+ "vision_sampler_{}".format(query_group_i),
380
+ ),
381
+ strict=True,
382
+ )
383
+
384
+ if not connector_only:
385
+ self.vision_sampler_layers.load_state_dict(
386
+ get_w(mm_projector_weights, "vision_sampler_layers"),
387
+ strict=True,
388
+ )
389
+ self.vision_query.data = mm_projector_weights["model.vision_query"]
390
+ self.image_newline.data = mm_projector_weights["model.image_newline"]
391
+
392
+
393
+ def unmask_attention_mask(mask, original_size):
394
+ original_w, original_h = original_size
395
+ cur_h, cur_w = mask.shape[1:3]
396
+
397
+ original_aspect_ratio = original_w / original_h
398
+ current_aspect_ratio = cur_w / cur_h
399
+
400
+ if original_aspect_ratio > current_aspect_ratio:
401
+ scale_factor = cur_w / original_w
402
+ new_height = int(original_h * scale_factor)
403
+ padding = (cur_h - new_height) // 2
404
+ if padding > 0:
405
+ mask[:, :padding, :] = 0
406
+ mask[:, -padding:, :] = 0
407
+ return mask
408
+ else:
409
+ scale_factor = cur_h / original_h
410
+ new_width = int(original_w * scale_factor)
411
+ padding = (cur_w - new_width) // 2
412
+ if padding > 0:
413
+ mask[:, :, :padding] = 0
414
+ mask[:, :, -padding:] = 0
415
+ return mask
416
+
417
+
418
+ def unpad_image(tensor, original_size):
419
+ """
420
+ Unpads a PyTorch tensor of a padded and resized image.
421
+
422
+ Args:
423
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
424
+ original_size (tuple): The original size of the image (height, width).
425
+
426
+ Returns:
427
+ torch.Tensor: The unpadded image tensor.
428
+ """
429
+ original_width, original_height = original_size
430
+ current_height, current_width = tensor.shape[1:3]
431
+
432
+ original_aspect_ratio = original_width / original_height
433
+ current_aspect_ratio = current_width / current_height
434
+
435
+ if original_aspect_ratio > current_aspect_ratio:
436
+ scale_factor = current_width / original_width
437
+ new_height = int(original_height * scale_factor)
438
+ padding = (current_height - new_height) // 2
439
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
440
+ # if 0 in unpadded_tensor.shape:
441
+ # print(f"scale_factor: {scale_factor}, new_height: {new_height}, padding: {padding}, original_width: {original_width}, original_height: {original_height}")
442
+ else:
443
+ scale_factor = current_height / original_height
444
+ new_width = int(original_width * scale_factor)
445
+ padding = (current_width - new_width) // 2
446
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
447
+ # if 0 in unpadded_tensor.shape:
448
+ # print(f"scale_factor: {scale_factor}, new_width: {new_width}, padding: {padding}, original_width: {original_width}, original_height: {original_height}")
449
+
450
+ return unpadded_tensor
451
+
452
+
453
+ class CambrianMetaForCausalLM(ABC):
454
+
455
+ @abstractmethod
456
+ def get_model(self):
457
+ pass
458
+
459
+ # def get_vision_tower(self):
460
+ # return self.get_model().get_vision_tower()
461
+
462
+ def get_vision_tower_aux_list(self):
463
+ return self.get_model().get_vision_tower_aux_list()
464
+
465
+ def rearrange_vision_tower_features_train(
466
+ self,
467
+ vision_tower_aux_feature_list,
468
+ vision_tower_aux_attention_masks_list,
469
+ query_side_len,
470
+ ):
471
+ vision_tower_aux_feature_rearranged_list = []
472
+ vision_tower_aux_attention_masks_rearranged_list = []
473
+ bs = vision_tower_aux_feature_list[0].shape[0]
474
+ for vision_tower_aux_feature, vision_tower_aux_attention_masks in zip(
475
+ vision_tower_aux_feature_list, vision_tower_aux_attention_masks_list
476
+ ):
477
+ aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5)
478
+ assert (aux_height // query_side_len) * query_side_len == aux_height
479
+
480
+ reduce_factor = aux_height // query_side_len
481
+ vision_tower_aux_feature_rearranged = vision_tower_aux_feature.view(
482
+ bs, query_side_len, reduce_factor, query_side_len, reduce_factor, -1
483
+ )
484
+ vision_tower_aux_feature_rearranged = (
485
+ vision_tower_aux_feature_rearranged.permute(0, 1, 3, 2, 4, 5)
486
+ .contiguous()
487
+ .flatten(0, 2)
488
+ .flatten(1, 2)
489
+ )
490
+
491
+ vision_tower_aux_attention_masks_rearranged = (
492
+ vision_tower_aux_attention_masks.view(
493
+ bs * query_side_len * query_side_len, reduce_factor * reduce_factor
494
+ )
495
+ )
496
+
497
+ vision_tower_aux_feature_rearranged_list.append(
498
+ vision_tower_aux_feature_rearranged
499
+ )
500
+ vision_tower_aux_attention_masks_rearranged_list.append(
501
+ vision_tower_aux_attention_masks_rearranged
502
+ )
503
+ return (
504
+ vision_tower_aux_feature_rearranged_list,
505
+ vision_tower_aux_attention_masks_rearranged_list,
506
+ )
507
+
508
+ def rearrange_vision_tower_features_inference(
509
+ self, vision_tower_aux_feature_list, query_side_len, image_sizes, unpad=False
510
+ ):
511
+ vision_tower_aux_feature_rearranged_list = []
512
+ vision_tower_aux_attention_masks_rearranged_list = []
513
+ bs = vision_tower_aux_feature_list[0].shape[0]
514
+ for vision_tower_aux_feature in vision_tower_aux_feature_list:
515
+ aux_height = aux_width = int(vision_tower_aux_feature.shape[1] ** 0.5)
516
+ assert (aux_height // query_side_len) * query_side_len == aux_height
517
+
518
+ reduce_factor = aux_height // query_side_len
519
+
520
+ vision_tower_aux_feature_rearranged = []
521
+ vision_tower_aux_attention_masks_rearranged = []
522
+ for batch_i in range(bs):
523
+ image_size = image_sizes[batch_i]
524
+ cur_vision_tower_aux_feature = vision_tower_aux_feature[batch_i]
525
+
526
+ cur_vision_tower_aux_attention_masks_rearranged = torch.ones(
527
+ (1, aux_height, aux_width),
528
+ dtype=torch.bool,
529
+ device=cur_vision_tower_aux_feature.device,
530
+ )
531
+ cur_vision_tower_aux_feature_rearranged = (
532
+ cur_vision_tower_aux_feature.view(
533
+ 1,
534
+ query_side_len,
535
+ reduce_factor,
536
+ query_side_len,
537
+ reduce_factor,
538
+ -1,
539
+ )
540
+ )
541
+ cur_vision_tower_aux_feature_rearranged = (
542
+ cur_vision_tower_aux_feature_rearranged.permute(
543
+ 0, 1, 3, 2, 4, 5
544
+ ).contiguous()
545
+ )
546
+ if unpad:
547
+ cur_vision_tower_aux_feature_rearranged = unpad_image(
548
+ cur_vision_tower_aux_feature_rearranged, image_size
549
+ )
550
+ cur_vision_tower_aux_feature_rearranged = (
551
+ cur_vision_tower_aux_feature_rearranged.flatten(0, 2).flatten(1, 2)
552
+ ) # query_side_len*query_side_len X reduce_factor*reduce_factor X C
553
+
554
+ cur_vision_tower_aux_attention_masks_rearranged = unmask_attention_mask(
555
+ cur_vision_tower_aux_attention_masks_rearranged, image_size
556
+ )
557
+ cur_vision_tower_aux_attention_masks_rearranged = (
558
+ cur_vision_tower_aux_attention_masks_rearranged.view(
559
+ 1, query_side_len, reduce_factor, query_side_len, reduce_factor
560
+ )
561
+ .permute(0, 1, 3, 2, 4)
562
+ .contiguous()
563
+ )
564
+ if unpad:
565
+ cur_vision_tower_aux_attention_masks_rearranged = unpad_image(
566
+ cur_vision_tower_aux_attention_masks_rearranged, image_size
567
+ )
568
+ cur_vision_tower_aux_attention_masks_rearranged = (
569
+ cur_vision_tower_aux_attention_masks_rearranged.flatten(
570
+ 0, 2
571
+ ).flatten(1, 2)
572
+ )
573
+
574
+ cur_vision_tower_aux_attention_masks_rearranged[
575
+ cur_vision_tower_aux_attention_masks_rearranged.sum(-1) == 0
576
+ ] = True
577
+
578
+ vision_tower_aux_feature_rearranged.append(
579
+ cur_vision_tower_aux_feature_rearranged
580
+ )
581
+ vision_tower_aux_attention_masks_rearranged.append(
582
+ cur_vision_tower_aux_attention_masks_rearranged
583
+ )
584
+
585
+ vision_tower_aux_feature_rearranged = torch.cat(
586
+ vision_tower_aux_feature_rearranged, 0
587
+ )
588
+ vision_tower_aux_attention_masks_rearranged = torch.cat(
589
+ vision_tower_aux_attention_masks_rearranged, 0
590
+ )
591
+
592
+ vision_tower_aux_feature_rearranged_list.append(
593
+ vision_tower_aux_feature_rearranged
594
+ )
595
+ vision_tower_aux_attention_masks_rearranged_list.append(
596
+ vision_tower_aux_attention_masks_rearranged
597
+ )
598
+
599
+ return (
600
+ vision_tower_aux_feature_rearranged_list,
601
+ vision_tower_aux_attention_masks_rearranged_list,
602
+ )
603
+
604
+ def encode_images(self, image_aux_list, encode_type=None):
605
+ vision_tower_aux_list = self.get_model().get_vision_tower_aux_list()
606
+ image_aux_features_list = []
607
+ chunk_size = 64
608
+ if encode_type == "dino":
609
+ image_aux = image_aux_list[-1]
610
+ vision_tower_aux = vision_tower_aux_list[-1]
611
+ if image_aux.shape[0] > chunk_size:
612
+ image_aux_features_chunks = []
613
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
614
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
615
+ chunk = image_aux[start_idx:end_idx]
616
+ image_aux_features_chunk = vision_tower_aux(chunk)
617
+ image_aux_features_chunks.append(image_aux_features_chunk)
618
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
619
+ else:
620
+ image_aux_features = vision_tower_aux(image_aux)
621
+ return image_aux_features
622
+ elif encode_type == "siglip":
623
+ image_aux = image_aux_list[0]
624
+ vision_tower_aux = vision_tower_aux_list[0]
625
+ if image_aux.shape[0] > chunk_size:
626
+ image_aux_features_chunks = []
627
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
628
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
629
+ chunk = image_aux[start_idx:end_idx]
630
+ image_aux_features_chunk = vision_tower_aux(chunk)
631
+ image_aux_features_chunks.append(image_aux_features_chunk)
632
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
633
+ else:
634
+ image_aux_features = vision_tower_aux(image_aux)
635
+ return image_aux_features
636
+ else:
637
+ for image_aux, vision_tower_aux in zip(
638
+ image_aux_list, vision_tower_aux_list
639
+ ):
640
+ if image_aux.shape[0] > chunk_size:
641
+ image_aux_features_chunks = []
642
+ for start_idx in range(0, image_aux.shape[0], chunk_size):
643
+ end_idx = min(start_idx + chunk_size, image_aux.shape[0])
644
+ chunk = image_aux[start_idx:end_idx]
645
+ image_aux_features_chunk = vision_tower_aux(chunk)
646
+ image_aux_features_chunks.append(image_aux_features_chunk)
647
+ image_aux_features = torch.cat(image_aux_features_chunks, dim=0)
648
+ else:
649
+ image_aux_features = vision_tower_aux(image_aux)
650
+ image_aux_features_list.append(image_aux_features)
651
+ return image_aux_features_list
652
+
653
+ def select_frame(
654
+ self,
655
+ feature_list,
656
+ split_sizes,
657
+ input_ids,
658
+ new_image_aux_list,
659
+ image_sizes,
660
+ window_size=16,
661
+ threshold=0.83,
662
+ ):
663
+ dino_features_batch = torch.split(feature_list, split_sizes, dim=0)
664
+ new_image_aux_batch_0 = torch.split(new_image_aux_list[0], split_sizes, dim=0)
665
+ new_image_aux_batch_1 = torch.split(new_image_aux_list[1], split_sizes, dim=0)
666
+ new_split_sizes = []
667
+ selected_frames_all_0 = []
668
+ selected_frames_all_1 = []
669
+ selected_frames_feature_all = []
670
+ selected_frame_indices_all = []
671
+ for i_batch, frame_features in enumerate(dino_features_batch):
672
+ try:
673
+ if "llama" in self.get_model().config.model_type:
674
+ text_len = torch.where(input_ids[i_batch] == 128002)[-1][0]
675
+ else:
676
+ text_len = torch.where(input_ids[i_batch] == 151643)[-1][0]
677
+ except:
678
+ text_len = len(input_ids[i_batch])
679
+ original_width, original_height = image_sizes[i_batch]
680
+ if getattr(self.get_model().config, "highres", False):
681
+ token_per_frame = self.get_model().config.lowres_token ** 2
682
+ else:
683
+ token_per_frame = self.get_model().config.image_token_len
684
+ # current_height, current_width = token_per_side, token_per_side
685
+ # original_aspect_ratio = original_width / original_height
686
+ # current_aspect_ratio = current_width / current_height
687
+ # if original_aspect_ratio > current_aspect_ratio:
688
+ # scale_factor = current_width / original_width
689
+ # new_height = int(original_height * scale_factor)
690
+ # padding = math.ceil((current_height - new_height) / 2.0)
691
+ # token_per_frame = (
692
+ # current_height - padding * 2
693
+ # ) * token_per_side + token_per_side
694
+ # else:
695
+ # scale_factor = current_height / original_height
696
+ # new_width = int(original_width * scale_factor)
697
+ # padding = math.ceil((current_width - new_width) / 2.0)
698
+ # token_per_frame = (current_width - padding * 2) * token_per_side + (
699
+ # current_width - padding * 2
700
+ # )
701
+ # token_per_frame = (
702
+ # token_per_side**2 if token_per_frame < 1 else token_per_frame
703
+ # )
704
+ max_num_frames = max(
705
+ 1,
706
+ (
707
+ self.get_model().config.tokenizer_model_max_length
708
+ - text_len
709
+ - getattr(self.get_model().config, "inference_max_length", 16)
710
+ )
711
+ // token_per_frame,
712
+ )
713
+ if len(frame_features) < max_num_frames:
714
+ selected_frames_all_0.append(new_image_aux_batch_0[i_batch])
715
+ selected_frames_all_1.append(new_image_aux_batch_1[i_batch])
716
+ selected_frames_feature_all.append(frame_features)
717
+ new_split_sizes.append(len(frame_features))
718
+ selected_frame_indices_all.append(torch.arange(len(frame_features)))
719
+ continue
720
+
721
+ num_segments = len(frame_features) // window_size
722
+ if num_segments == 0:
723
+ query_feature = frame_features.flatten(1, 2)
724
+ query_feature = query_feature / torch.norm(
725
+ (query_feature), dim=1, keepdim=True
726
+ )
727
+ similarities = torch.mean(query_feature @ query_feature.T, dim=1)
728
+ similarities[len(frame_features) // 2] = 0
729
+ indices = torch.where(similarities < threshold)[0]
730
+ selected_frame_indices_all.append(indices)
731
+ selected_frames_all_0.append(new_image_aux_batch_0[i_batch][indices])
732
+ selected_frames_all_1.append(new_image_aux_batch_1[i_batch][indices])
733
+ selected_frames_feature_all.append(frame_features[indices])
734
+ new_split_sizes.append(len(indices))
735
+ continue
736
+ segments_frames_0 = []
737
+ segments_frames_1 = []
738
+ segments_features = []
739
+ for start_idx in range(0, len(frame_features), window_size):
740
+ end_idx = min(start_idx + window_size, len(frame_features))
741
+ segments_frames_0.append(
742
+ new_image_aux_batch_0[i_batch][start_idx:end_idx]
743
+ )
744
+ segments_frames_1.append(
745
+ new_image_aux_batch_1[i_batch][start_idx:end_idx]
746
+ )
747
+ segments_features.append(frame_features[start_idx:end_idx])
748
+ selected_frames_0 = []
749
+ selected_frames_1 = []
750
+ selected_features = []
751
+ selected_frame_indices = []
752
+ for i, segment in enumerate(segments_features):
753
+ query_feature = segment.flatten(1, 2)
754
+ query_feature = query_feature / torch.norm(
755
+ (query_feature), dim=1, keepdim=True
756
+ )
757
+ similarities = torch.mean(query_feature @ query_feature.T, dim=1)
758
+ similarities[len(segment) // 2] = 0
759
+ indices = torch.where(similarities < threshold)[0]
760
+ selected_frames_0.append(segments_frames_0[i][indices])
761
+ selected_frames_1.append(segments_frames_1[i][indices])
762
+ selected_features.append(segment[indices])
763
+ selected_frame_indices.extend(indices + i * window_size)
764
+ selected_frames_0 = torch.cat(selected_frames_0, dim=0)
765
+ selected_frames_1 = torch.cat(selected_frames_1, dim=0)
766
+ selected_features = torch.cat(selected_features, dim=0)
767
+ selected_frame_indices = torch.tensor(selected_frame_indices)
768
+ # ablation
769
+ max_num_frames = 400 # in case of OOM
770
+ if len(selected_frames_0) > max_num_frames:
771
+ interval = len(selected_frames_0) / float(max_num_frames)
772
+ indices = [int(interval * i) for i in range(max_num_frames)]
773
+ new_split_sizes.append(len(indices))
774
+ selected_frames_all_0.append(selected_frames_0[indices])
775
+ selected_frames_all_1.append(selected_frames_1[indices])
776
+ selected_frames_feature_all.append(selected_features[indices])
777
+ selected_frame_indices = selected_frame_indices[indices]
778
+ else:
779
+ new_split_sizes.append(len(selected_frames_0))
780
+ selected_frames_all_0.append(selected_frames_0)
781
+ selected_frames_all_1.append(selected_frames_1)
782
+ selected_frames_feature_all.append(selected_features)
783
+ selected_frame_indices_all.append(selected_frame_indices)
784
+ selected_frames_all_0 = torch.cat(selected_frames_all_0, dim=0)
785
+ selected_frames_all_1 = torch.cat(selected_frames_all_1, dim=0)
786
+ selected_frames_feature_all = torch.cat(selected_frames_feature_all, dim=0)
787
+ return (
788
+ selected_frames_feature_all,
789
+ new_split_sizes,
790
+ [selected_frames_all_0, selected_frames_all_1],
791
+ selected_frame_indices_all,
792
+ )
793
+
794
+ def prepare_inputs_labels_for_multimodal(
795
+ self,
796
+ input_ids,
797
+ position_ids,
798
+ attention_mask,
799
+ past_key_values,
800
+ labels,
801
+ images,
802
+ image_aux_attention_masks_list=None,
803
+ image_sizes=None,
804
+ ):
805
+ # vision_tower = self.get_vision_tower()
806
+ vision_tower_aux_list = self.get_model().get_vision_tower_aux_list()
807
+ if vision_tower_aux_list is None or images is None or input_ids.shape[1] == 1:
808
+ return (
809
+ input_ids,
810
+ position_ids,
811
+ attention_mask,
812
+ past_key_values,
813
+ None,
814
+ labels,
815
+ None,
816
+ None,
817
+ None,
818
+ None,
819
+ )
820
+
821
+ image_aux_list = images
822
+
823
+ split_sizes = None
824
+
825
+ if type(image_aux_list[0]) is list or image_aux_list[0].ndim == 5:
826
+ split_sizes_ori = [
827
+ 1 if image.ndim == 3 else image.shape[0] for image in image_aux_list[0]
828
+ ]
829
+ new_image_aux_list = []
830
+ for image_aux in image_aux_list:
831
+ if type(image_aux) is list:
832
+ image_aux = [
833
+ x.unsqueeze(0) if x.ndim == 3 else x for x in image_aux
834
+ ]
835
+ concat_image_aux = torch.cat([image for image in image_aux], dim=0)
836
+ new_image_aux_list.append(concat_image_aux)
837
+ image_aux_features_dino = self.encode_images(
838
+ new_image_aux_list, encode_type="dino"
839
+ )
840
+
841
+ (
842
+ image_aux_features_dino,
843
+ split_sizes,
844
+ new_image_aux_list,
845
+ selected_frame_indices_all,
846
+ ) = self.select_frame(
847
+ image_aux_features_dino,
848
+ split_sizes_ori,
849
+ input_ids,
850
+ new_image_aux_list,
851
+ image_sizes,
852
+ threshold=getattr(self.get_model().config, "dino_threshold", 0.83),
853
+ )
854
+
855
+ image_aux_features_siglip = self.encode_images(
856
+ new_image_aux_list, encode_type="siglip"
857
+ )
858
+ image_aux_features_list = [
859
+ image_aux_features_siglip,
860
+ image_aux_features_dino,
861
+ ]
862
+
863
+ bs = image_aux_features_list[0].shape[0]
864
+ dtype = new_image_aux_list[0].dtype
865
+
866
+ frame_sizes = []
867
+ for i in range(len(image_sizes)):
868
+ for j in range(split_sizes[i]):
869
+ frame_sizes.append(image_sizes[i])
870
+ image_sizes = frame_sizes
871
+ else:
872
+ image_aux_features_list = self.encode_images(image_aux_list)
873
+ bs = image_aux_list[0].shape[0]
874
+ dtype = image_aux_list[0].dtype
875
+
876
+ image_token_len = self.get_model().config.image_token_len
877
+ query_num_list = self.get_model().config.query_num_list
878
+
879
+ final_height = final_width = int(image_token_len**0.5)
880
+
881
+ final_image_features_list = []
882
+ final_image_features_down_list = []
883
+
884
+ # only needed for sva
885
+ vision_tower_aux_feature_list_final = None
886
+ vision_tower_aux_attention_masks_list_final = None
887
+ global_context_feature_final = None
888
+
889
+ if self.get_model().config.mm_projector_type == "sva":
890
+ vision_tower_aux_feature_list = []
891
+ vision_tower_aux_attention_masks_list = []
892
+ # get vision tokens from each vision tower
893
+ for aux_i in range(len(vision_tower_aux_list)):
894
+ image_aux_features = image_aux_features_list[aux_i]
895
+
896
+ image_aux_features = getattr(
897
+ self.get_model(), "mm_projector_aux_{}".format(aux_i)
898
+ )(image_aux_features).to(dtype)
899
+ if aux_i == 0:
900
+ global_context_feature = image_aux_features.mean(1).view(
901
+ bs, 1, 1, -1
902
+ )
903
+
904
+ vision_tower_aux_feature_list.append(image_aux_features)
905
+ input_mix_res = True
906
+ input_high_res = True
907
+ # perform vision sampling for each query group
908
+ for query_group_i, query_num in enumerate(query_num_list):
909
+ query_features_i = (
910
+ self.get_model()
911
+ .vision_query[query_group_i, :]
912
+ .view(1, 1, 1, -1)
913
+ .expand(bs, query_num, -1, -1)
914
+ )
915
+ global_context_feature_i = global_context_feature.expand(
916
+ -1, query_num, 1, -1
917
+ ).flatten(0, 1)
918
+ query_side_len = int(query_num**0.5)
919
+ if IS_XLA_AVAILABLE:
920
+ (
921
+ vision_tower_aux_feature_list_i,
922
+ vision_tower_aux_attention_masks_list_i,
923
+ ) = self.rearrange_vision_tower_features_train(
924
+ vision_tower_aux_feature_list,
925
+ image_aux_attention_masks_list,
926
+ query_side_len,
927
+ )
928
+ else:
929
+ (
930
+ vision_tower_aux_feature_list_i,
931
+ vision_tower_aux_attention_masks_list_i,
932
+ ) = self.rearrange_vision_tower_features_inference(
933
+ vision_tower_aux_feature_list, query_side_len, image_sizes
934
+ )
935
+
936
+ query_features_i = getattr(
937
+ self.get_model(), "vision_sampler_{}".format(query_group_i)
938
+ )(
939
+ query_features_i.flatten(0, 1),
940
+ global_context_feature_i,
941
+ *vision_tower_aux_feature_list_i,
942
+ *vision_tower_aux_attention_masks_list_i,
943
+ )
944
+ query_features_i = query_features_i.view(bs, query_num, -1)
945
+
946
+ if split_sizes is not None:
947
+ try:
948
+ if "llama" in self.get_model().config.model_type:
949
+ text_len = torch.where(input_ids[0] == 128002)[-1][0]
950
+ else:
951
+ text_len = torch.where(input_ids[0] == 151643)[-1][0]
952
+ except:
953
+ text_len = len(input_ids[0])
954
+ max_visual_len = (
955
+ self.get_model().config.tokenizer_model_max_length
956
+ - text_len
957
+ - getattr(self.get_model().config, "inference_max_length", 16)
958
+ )
959
+ max_num_frames = max(
960
+ 1,
961
+ math.floor(max_visual_len // (final_height * final_width)),
962
+ )
963
+ max_num_frames_low = max(
964
+ 1,
965
+ math.floor(
966
+ max_visual_len
967
+ // (self.get_model().config.lowres_token ** 2)
968
+ ),
969
+ )
970
+ if split_sizes[0] < max_num_frames:
971
+ input_mix_res = False
972
+ elif split_sizes[0] > max_num_frames_low:
973
+ input_mix_res = False
974
+ input_high_res = False
975
+
976
+ # input_mix_res = False # ablation
977
+
978
+ if (getattr(self.config, "highres", False)) and input_mix_res:
979
+ _query_features_i = (
980
+ query_features_i.permute(0, 2, 1)
981
+ .contiguous()
982
+ .view(bs, -1, query_side_len, query_side_len)
983
+ )
984
+ _query_features_i = F.interpolate(
985
+ _query_features_i.float(),
986
+ size=(
987
+ self.get_model().config.lowres_token,
988
+ self.get_model().config.lowres_token,
989
+ ),
990
+ mode="bilinear",
991
+ align_corners=False,
992
+ ).to(dtype=query_features_i.dtype)
993
+ _query_features_i = (
994
+ _query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
995
+ )
996
+ final_image_features_down_list.append(_query_features_i)
997
+
998
+ # interpolate to the final target size
999
+ if query_side_len != final_height:
1000
+ query_features_i = (
1001
+ query_features_i.permute(0, 2, 1)
1002
+ .contiguous()
1003
+ .view(bs, -1, query_side_len, query_side_len)
1004
+ )
1005
+ if input_high_res:
1006
+ query_features_i = F.interpolate(
1007
+ query_features_i.float(),
1008
+ size=(final_height, final_width),
1009
+ mode="bilinear",
1010
+ align_corners=False,
1011
+ ).to(dtype=query_features_i.dtype)
1012
+ else:
1013
+ query_features_i = F.interpolate(
1014
+ query_features_i.float(),
1015
+ size=(8, 8),
1016
+ mode="bilinear",
1017
+ align_corners=False,
1018
+ ).to(dtype=query_features_i.dtype)
1019
+ query_features_i = (
1020
+ query_features_i.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
1021
+ )
1022
+ final_image_features_list.append(query_features_i)
1023
+
1024
+ if IS_XLA_AVAILABLE:
1025
+ (
1026
+ vision_tower_aux_feature_list_final,
1027
+ vision_tower_aux_attention_masks_list_final,
1028
+ ) = self.rearrange_vision_tower_features_train(
1029
+ vision_tower_aux_feature_list,
1030
+ image_aux_attention_masks_list,
1031
+ final_height,
1032
+ )
1033
+ global_context_feature_final = global_context_feature.expand(
1034
+ -1, final_height * final_width, 1, -1
1035
+ ).flatten(0, 1)
1036
+ else:
1037
+ final_image_features_list = image_aux_features_list
1038
+
1039
+ image_features = torch.cat(final_image_features_list, -1)
1040
+ image_features = self.get_model().mm_projector(image_features).to(dtype)
1041
+
1042
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1043
+ image_features_down = torch.cat(final_image_features_down_list, -1)
1044
+ image_features_down = (
1045
+ self.get_model().mm_projector(image_features_down).to(dtype)
1046
+ )
1047
+
1048
+ if IS_XLA_AVAILABLE:
1049
+ image_features = image_features.view(
1050
+ image_features.shape[0], final_height, final_width, -1
1051
+ )
1052
+ image_features = torch.cat(
1053
+ (
1054
+ image_features,
1055
+ self.model.image_newline[None, None, None, :].expand(
1056
+ image_features.shape[0], final_height, 1, -1
1057
+ ),
1058
+ ),
1059
+ dim=2,
1060
+ )
1061
+ image_features = image_features.flatten(1, 2)
1062
+ final_size = [(final_height, final_width)] * bs
1063
+
1064
+ else:
1065
+ image_features = image_features.view(bs, final_height, final_width, -1)
1066
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1067
+ image_features_down = image_features_down.view(
1068
+ bs,
1069
+ self.get_model().config.lowres_token,
1070
+ self.get_model().config.lowres_token,
1071
+ -1,
1072
+ )
1073
+ image_features_unpadded = []
1074
+ image_features_downsample = []
1075
+ final_size = []
1076
+ if self.get_model().config.mm_projector_type == "sva":
1077
+ (
1078
+ vision_tower_aux_feature_list_final,
1079
+ vision_tower_aux_attention_masks_list_final,
1080
+ ) = self.rearrange_vision_tower_features_inference(
1081
+ vision_tower_aux_feature_list, final_height, image_sizes, unpad=True
1082
+ )
1083
+ global_context_feature_final = []
1084
+ for batch_i in range(bs):
1085
+ cur_image_feature = image_features[batch_i]
1086
+ image_size = image_sizes[batch_i]
1087
+
1088
+ cur_image_feature = unpad_image(
1089
+ cur_image_feature.unsqueeze(0), image_size
1090
+ )
1091
+
1092
+ cur_h, cur_w = cur_image_feature.shape[1:3]
1093
+ try: # fix bug for some invalid image
1094
+ cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1)
1095
+ final_size.append((cur_h, cur_w))
1096
+ except:
1097
+ # print(f"invalid after unpad {image_features[batch_i].shape}, {image_sizes[batch_i]}", flush=True)
1098
+ cur_image_feature = image_features[batch_i].unsqueeze(0)
1099
+ image_size = image_sizes[batch_i]
1100
+ cur_h, cur_w = cur_image_feature.shape[1:3]
1101
+ cur_image_feature = cur_image_feature.view(1, cur_h, cur_w, -1)
1102
+ final_size.append((cur_h, cur_w))
1103
+
1104
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1105
+ cur_image_feature_down = unpad_image(
1106
+ image_features_down[batch_i].unsqueeze(0),
1107
+ (
1108
+ int(
1109
+ image_size[0]
1110
+ / (
1111
+ image_token_len**0.5
1112
+ / self.get_model().config.lowres_token
1113
+ )
1114
+ ),
1115
+ int(
1116
+ image_size[1]
1117
+ / (
1118
+ image_token_len**0.5
1119
+ / self.get_model().config.lowres_token
1120
+ )
1121
+ ),
1122
+ ),
1123
+ )
1124
+ _cur_h, _cur_w = cur_image_feature_down.shape[1:3]
1125
+
1126
+ try: # fix bug for some invalid image
1127
+ cur_image_feature_down = cur_image_feature_down.view(
1128
+ 1, _cur_h, _cur_w, -1
1129
+ )
1130
+ except:
1131
+ print("invalid after unpad", flush=True)
1132
+ cur_image_feature_down = image_features_down[batch_i].unsqueeze(
1133
+ 0
1134
+ )
1135
+ _cur_h, _cur_w = cur_image_feature_down.shape[1:3]
1136
+ cur_image_feature_down = cur_image_feature_down.view(
1137
+ 1, _cur_h, _cur_w, -1
1138
+ )
1139
+
1140
+ cur_image_feature_down = torch.cat(
1141
+ (
1142
+ cur_image_feature_down,
1143
+ self.model.image_newline.view(1, 1, 1, -1)
1144
+ .expand(1, _cur_h, 1, -1)
1145
+ .to(cur_image_feature_down.device),
1146
+ ),
1147
+ dim=2,
1148
+ ).flatten(1, 2)
1149
+
1150
+ if split_sizes is None and getattr(self.config, "frame_pos", False):
1151
+ frame_pos = (
1152
+ self.get_model()
1153
+ .get_frame_pos(torch.arange(1))
1154
+ .to(cur_image_feature_down.device)
1155
+ .to(cur_image_feature_down.dtype)
1156
+ )
1157
+ cur_image_feature_down += frame_pos
1158
+
1159
+ image_features_downsample.append(cur_image_feature_down.squeeze(0))
1160
+
1161
+ cur_image_feature = torch.cat(
1162
+ (
1163
+ cur_image_feature,
1164
+ self.model.image_newline.view(1, 1, 1, -1)
1165
+ .expand(1, cur_h, 1, -1)
1166
+ .to(cur_image_feature.device),
1167
+ ),
1168
+ dim=2,
1169
+ )
1170
+
1171
+ if split_sizes is None and getattr(self.config, "frame_pos", False):
1172
+ frame_pos = (
1173
+ self.get_model()
1174
+ .get_frame_pos(torch.arange(1))
1175
+ .to(cur_image_feature.device)
1176
+ .to(cur_image_feature.dtype)
1177
+ )
1178
+ cur_image_feature += frame_pos
1179
+
1180
+ cur_image_feature = cur_image_feature.flatten(1, 2)
1181
+ image_features_unpadded.append(cur_image_feature.squeeze(0))
1182
+
1183
+ if self.get_model().config.mm_projector_type == "sva":
1184
+ cur_global_context_feature = global_context_feature[batch_i].expand(
1185
+ cur_h * cur_w, 1, -1
1186
+ )
1187
+ global_context_feature_final.append(cur_global_context_feature)
1188
+ if self.get_model().config.mm_projector_type == "sva":
1189
+ global_context_feature_final = torch.cat(
1190
+ global_context_feature_final, 0
1191
+ )
1192
+
1193
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1194
+ image_features = image_features_downsample
1195
+ else:
1196
+ image_features = image_features_unpadded
1197
+
1198
+ # TODO: image start / end is not implemented here to support pretraining.
1199
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
1200
+ self.config, "mm_use_im_start_end", False
1201
+ ):
1202
+ raise NotImplementedError
1203
+
1204
+ split_image_features_unpadded = None
1205
+ frame_split_sizes = None
1206
+
1207
+ if split_sizes is not None:
1208
+ split_image_features = []
1209
+ split_image_features_unpadded = (
1210
+ []
1211
+ if (getattr(self.config, "highres", False)) and input_mix_res
1212
+ else None
1213
+ )
1214
+ start_idx = 0
1215
+ for split_batch_idx, split_size in enumerate(split_sizes):
1216
+ if isinstance(image_features[start_idx : start_idx + split_size], list):
1217
+ if getattr(self.config, "frame_pos", False):
1218
+ frame_feature = torch.cat(
1219
+ image_features[start_idx : start_idx + split_size], dim=0
1220
+ ).reshape(split_size, -1, image_features[0].shape[-1])
1221
+ frame_pos = (
1222
+ self.get_model()
1223
+ .get_frame_pos(selected_frame_indices_all[split_batch_idx])
1224
+ .to(frame_feature.device)
1225
+ .to(frame_feature.dtype)
1226
+ )
1227
+ frame_feature += frame_pos
1228
+ split_image_features.append(
1229
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1230
+ )
1231
+ else:
1232
+ split_image_features.append(
1233
+ torch.cat(
1234
+ image_features[start_idx : start_idx + split_size],
1235
+ dim=0,
1236
+ )
1237
+ )
1238
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1239
+ if getattr(self.config, "frame_pos", False):
1240
+ frame_feature = torch.cat(
1241
+ image_features_unpadded[
1242
+ start_idx : start_idx + split_size
1243
+ ],
1244
+ dim=0,
1245
+ ).reshape(split_size, -1, image_features[0].shape[-1])
1246
+ frame_pos = (
1247
+ self.get_model()
1248
+ .get_frame_pos(
1249
+ selected_frame_indices_all[split_batch_idx]
1250
+ )
1251
+ .to(frame_feature.device)
1252
+ .to(frame_feature.dtype)
1253
+ )
1254
+ frame_feature += frame_pos
1255
+ split_image_features_unpadded.append(
1256
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1257
+ )
1258
+ else:
1259
+ split_image_features_unpadded.append(
1260
+ torch.cat(
1261
+ image_features_unpadded[
1262
+ start_idx : start_idx + split_size
1263
+ ],
1264
+ dim=0,
1265
+ )
1266
+ )
1267
+ else:
1268
+ if getattr(self.config, "frame_pos", False):
1269
+ frame_feature = image_features[
1270
+ start_idx : start_idx + split_size
1271
+ ].reshape(split_size, -1, image_features[0].shape[-1])
1272
+ frame_pos = (
1273
+ self.get_model()
1274
+ .get_frame_pos(selected_frame_indices_all[split_batch_idx])
1275
+ .to(frame_feature.device)
1276
+ .to(frame_feature.dtype)
1277
+ )
1278
+ frame_feature += frame_pos
1279
+ split_image_features.append(
1280
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1281
+ )
1282
+ else:
1283
+ split_image_features.append(
1284
+ image_features[start_idx : start_idx + split_size]
1285
+ )
1286
+ if (getattr(self.config, "highres", False)) and input_mix_res:
1287
+ if getattr(self.config, "frame_pos", False):
1288
+ frame_feature = image_features_unpadded[
1289
+ start_idx : start_idx + split_size
1290
+ ]
1291
+ frame_pos = (
1292
+ self.get_model()
1293
+ .get_frame_pos(
1294
+ selected_frame_indices_all[split_batch_idx]
1295
+ )
1296
+ .to(frame_feature.device)
1297
+ .to(frame_feature.dtype)
1298
+ )
1299
+ frame_feature += frame_pos
1300
+ split_image_features_unpadded.append(
1301
+ frame_feature.reshape(-1, image_features[0].shape[-1])
1302
+ )
1303
+ else:
1304
+ split_image_features_unpadded.append(
1305
+ image_features_unpadded[
1306
+ start_idx : start_idx + split_size
1307
+ ]
1308
+ )
1309
+ start_idx += split_size
1310
+ image_features = split_image_features
1311
+ frame_split_sizes = split_sizes
1312
+
1313
+ _labels = labels
1314
+ _position_ids = position_ids
1315
+ _attention_mask = attention_mask
1316
+ if attention_mask is None:
1317
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
1318
+ else:
1319
+ attention_mask = attention_mask.bool()
1320
+ if position_ids is None:
1321
+ position_ids = torch.arange(
1322
+ 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
1323
+ )
1324
+ if labels is None:
1325
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
1326
+
1327
+ # remove the padding using attention_mask -- FIXME
1328
+ _input_ids = input_ids
1329
+
1330
+ attention_mask = attention_mask | (input_ids == IMAGE_TOKEN_INDEX)
1331
+
1332
+ input_ids = [
1333
+ cur_input_ids[cur_attention_mask]
1334
+ for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
1335
+ ]
1336
+ labels = [
1337
+ cur_labels[cur_attention_mask]
1338
+ for cur_labels, cur_attention_mask in zip(labels, attention_mask)
1339
+ ]
1340
+
1341
+ new_input_embeds = []
1342
+ new_labels = []
1343
+ image_token_indices_batch = []
1344
+ cur_image_idx = 0
1345
+ for batch_idx, cur_input_ids in enumerate(input_ids):
1346
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
1347
+ if num_images == 0:
1348
+ cur_image_features = image_features[cur_image_idx]
1349
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
1350
+ cur_input_embeds = torch.cat(
1351
+ [cur_input_embeds_1, cur_image_features[0:0]], dim=0
1352
+ )
1353
+ new_input_embeds.append(cur_input_embeds)
1354
+ new_labels.append(labels[batch_idx])
1355
+ cur_image_idx += 1
1356
+ continue
1357
+
1358
+ image_token_indices = (
1359
+ [-1]
1360
+ + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
1361
+ + [cur_input_ids.shape[0]]
1362
+ )
1363
+ image_token_indices_batch.append(
1364
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()[0]
1365
+ )
1366
+ cur_input_ids_noim = []
1367
+ cur_labels = labels[batch_idx]
1368
+ cur_labels_noim = []
1369
+ for i in range(len(image_token_indices) - 1):
1370
+ cur_input_ids_noim.append(
1371
+ cur_input_ids[
1372
+ image_token_indices[i] + 1 : image_token_indices[i + 1]
1373
+ ]
1374
+ )
1375
+ cur_labels_noim.append(
1376
+ cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
1377
+ )
1378
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
1379
+ cur_input_embeds = self.get_model().embed_tokens(
1380
+ torch.cat(cur_input_ids_noim)
1381
+ )
1382
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
1383
+ cur_new_input_embeds = []
1384
+ cur_new_labels = []
1385
+
1386
+ text_len = sum([x.shape[0] for x in cur_input_embeds_no_im])
1387
+ visual_len = len(image_features[cur_image_idx])
1388
+ max_visual_len = (
1389
+ self.get_model().config.tokenizer_model_max_length
1390
+ - getattr(self.get_model().config, "inference_max_length", 16)
1391
+ - text_len
1392
+ )
1393
+ mix_token = False
1394
+
1395
+ # ablation mix
1396
+ if (
1397
+ input_mix_res
1398
+ and (
1399
+ self.get_model().config.image_token_len
1400
+ > getattr(self.get_model().config, "lowres_token", 8) ** 2
1401
+ )
1402
+ and frame_split_sizes is not None
1403
+ and getattr(self.config, "highres", False)
1404
+ ):
1405
+ if max_visual_len > visual_len:
1406
+ visual_emb = image_features[cur_image_idx]
1407
+ text_emb = cur_input_embeds_no_im[-1]
1408
+ highres_num = math.floor(
1409
+ (max_visual_len - visual_len)
1410
+ / (
1411
+ split_image_features_unpadded[cur_image_idx].shape[0]
1412
+ // frame_split_sizes[cur_image_idx]
1413
+ - visual_emb.shape[0] // frame_split_sizes[cur_image_idx]
1414
+ )
1415
+ )
1416
+ if highres_num >= 1:
1417
+ mix_token = True
1418
+ sim = torch.matmul(visual_emb, text_emb.transpose(0, 1)).mean(
1419
+ dim=-1
1420
+ )
1421
+ sim_frame = sim.reshape(
1422
+ frame_split_sizes[cur_image_idx], -1
1423
+ ).mean(dim=-1)
1424
+ highres_num = min(highres_num, sim_frame.shape[0])
1425
+ top_values, top_indices = torch.topk(sim_frame, highres_num)
1426
+ if len(top_indices) > 0:
1427
+ sorted_indices = torch.sort(top_indices)[1]
1428
+ top_indices = top_indices[sorted_indices]
1429
+ visual_emb_frame = image_features[cur_image_idx].reshape(
1430
+ frame_split_sizes[cur_image_idx],
1431
+ -1,
1432
+ image_features[cur_image_idx].shape[-1],
1433
+ )
1434
+ visual_emb_frame_highres = split_image_features_unpadded[
1435
+ cur_image_idx
1436
+ ].reshape(
1437
+ frame_split_sizes[cur_image_idx],
1438
+ -1,
1439
+ split_image_features_unpadded[cur_image_idx].shape[-1],
1440
+ )
1441
+ current_point = 0
1442
+ mix_visual_emb_frame = []
1443
+ for frame_i in range(len(visual_emb_frame)):
1444
+ if current_point > len(top_indices) - 1:
1445
+ mix_visual_emb_frame.append(
1446
+ visual_emb_frame[frame_i]
1447
+ )
1448
+ continue
1449
+ if frame_i == top_indices[current_point]:
1450
+ mix_visual_emb_frame.append(
1451
+ visual_emb_frame_highres[frame_i]
1452
+ )
1453
+ current_point += 1
1454
+ else:
1455
+ mix_visual_emb_frame.append(
1456
+ visual_emb_frame[frame_i]
1457
+ )
1458
+ image_features[cur_image_idx] = torch.cat(
1459
+ mix_visual_emb_frame, dim=0
1460
+ )
1461
+ # ablation drop
1462
+
1463
+ if (
1464
+ max_visual_len < visual_len
1465
+ and frame_split_sizes is not None
1466
+ and not mix_token
1467
+ ):
1468
+ visual_emb_frame = image_features[cur_image_idx].reshape(
1469
+ frame_split_sizes[cur_image_idx],
1470
+ -1,
1471
+ image_features[cur_image_idx].shape[-1],
1472
+ )
1473
+
1474
+ sim = F.cosine_similarity(
1475
+ visual_emb_frame[:-1],
1476
+ visual_emb_frame[1:],
1477
+ dim=-1,
1478
+ )
1479
+
1480
+ new_visual_emb_frames = []
1481
+ for start_idx in range(0, len(visual_emb_frame), 8):
1482
+ end_idx = min(start_idx + 8, len(visual_emb_frame))
1483
+ chunk_feature = visual_emb_frame[start_idx:end_idx] # 8, HW, C
1484
+ if len(chunk_feature) == 1:
1485
+ new_visual_emb_frames.append(chunk_feature[0])
1486
+ continue
1487
+ sim = F.cosine_similarity(
1488
+ chunk_feature[0]
1489
+ .unsqueeze(0)
1490
+ .repeat_interleave(len(chunk_feature[1:]), dim=0),
1491
+ chunk_feature[1:],
1492
+ dim=-1,
1493
+ )
1494
+ new_visual_emb_frame = torch.cat(
1495
+ [
1496
+ chunk_feature[0],
1497
+ chunk_feature[1:].flatten(0, 1)[
1498
+ sim.flatten(0, 1)
1499
+ < getattr(
1500
+ self.get_model().config, "drop_threshold", 0.7
1501
+ )
1502
+ ],
1503
+ ],
1504
+ dim=0,
1505
+ )
1506
+ new_visual_emb_frames.append(new_visual_emb_frame)
1507
+
1508
+ reduced_visual_len = sum([x.shape[0] for x in new_visual_emb_frames])
1509
+
1510
+ if reduced_visual_len > max_visual_len:
1511
+ force_remove = math.ceil(
1512
+ (reduced_visual_len - max_visual_len)
1513
+ / len(new_visual_emb_frames)
1514
+ )
1515
+ for chunk_i in range(len(new_visual_emb_frames)):
1516
+ new_visual_emb_frames[chunk_i] = new_visual_emb_frames[chunk_i][
1517
+ :-force_remove
1518
+ ]
1519
+ new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0)
1520
+ else:
1521
+ new_visual_emb_frames = torch.cat(new_visual_emb_frames, dim=0)
1522
+
1523
+ image_features[cur_image_idx] = new_visual_emb_frames[:max_visual_len]
1524
+
1525
+ for i in range(num_images + 1):
1526
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
1527
+ cur_new_labels.append(cur_labels_noim[i])
1528
+ if i < num_images:
1529
+ cur_image_features = image_features[cur_image_idx]
1530
+ cur_image_idx += 1
1531
+ cur_new_input_embeds.append(cur_image_features)
1532
+ cur_new_labels.append(
1533
+ torch.full(
1534
+ (cur_image_features.shape[0],),
1535
+ IGNORE_INDEX,
1536
+ device=cur_labels.device,
1537
+ dtype=cur_labels.dtype,
1538
+ )
1539
+ )
1540
+
1541
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
1542
+
1543
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
1544
+ cur_new_labels = torch.cat(cur_new_labels)
1545
+
1546
+ new_input_embeds.append(cur_new_input_embeds)
1547
+ new_labels.append(cur_new_labels)
1548
+
1549
+ # Truncate sequences to max length as image embeddings can make the sequence longer
1550
+ tokenizer_model_max_length = getattr(
1551
+ self.config, "tokenizer_model_max_length", None
1552
+ )
1553
+ if tokenizer_model_max_length is not None:
1554
+ new_input_embeds = [
1555
+ x[:tokenizer_model_max_length] for x in new_input_embeds
1556
+ ]
1557
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
1558
+
1559
+ # Combine them
1560
+ max_len = max(x.shape[0] for x in new_input_embeds)
1561
+ batch_size = len(new_input_embeds)
1562
+
1563
+ new_input_embeds_padded = []
1564
+ new_labels_padded = torch.full(
1565
+ (batch_size, max_len),
1566
+ IGNORE_INDEX,
1567
+ dtype=new_labels[0].dtype,
1568
+ device=new_labels[0].device,
1569
+ )
1570
+ attention_mask = torch.zeros(
1571
+ (batch_size, max_len),
1572
+ dtype=attention_mask.dtype,
1573
+ device=attention_mask.device,
1574
+ )
1575
+ position_ids = torch.zeros(
1576
+ (batch_size, max_len),
1577
+ dtype=position_ids.dtype,
1578
+ device=position_ids.device,
1579
+ )
1580
+
1581
+ for i, (cur_new_embed, cur_new_labels) in enumerate(
1582
+ zip(new_input_embeds, new_labels)
1583
+ ):
1584
+ cur_len = cur_new_embed.shape[0]
1585
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
1586
+ new_input_embeds_padded.append(
1587
+ torch.cat(
1588
+ (
1589
+ torch.zeros(
1590
+ (max_len - cur_len, cur_new_embed.shape[1]),
1591
+ dtype=cur_new_embed.dtype,
1592
+ device=cur_new_embed.device,
1593
+ ),
1594
+ cur_new_embed,
1595
+ ),
1596
+ dim=0,
1597
+ )
1598
+ )
1599
+ if cur_len > 0:
1600
+ new_labels_padded[i, -cur_len:] = cur_new_labels
1601
+ attention_mask[i, -cur_len:] = True
1602
+ position_ids[i, -cur_len:] = torch.arange(
1603
+ 0,
1604
+ cur_len,
1605
+ dtype=position_ids.dtype,
1606
+ device=position_ids.device,
1607
+ )
1608
+ else:
1609
+ new_input_embeds_padded.append(
1610
+ torch.cat(
1611
+ (
1612
+ cur_new_embed,
1613
+ torch.zeros(
1614
+ (max_len - cur_len, cur_new_embed.shape[1]),
1615
+ dtype=cur_new_embed.dtype,
1616
+ device=cur_new_embed.device,
1617
+ ),
1618
+ ),
1619
+ dim=0,
1620
+ )
1621
+ )
1622
+ if cur_len > 0:
1623
+ new_labels_padded[i, :cur_len] = cur_new_labels
1624
+ attention_mask[i, :cur_len] = True
1625
+ position_ids[i, :cur_len] = torch.arange(
1626
+ 0,
1627
+ cur_len,
1628
+ dtype=position_ids.dtype,
1629
+ device=position_ids.device,
1630
+ )
1631
+
1632
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
1633
+
1634
+ if _labels is None:
1635
+ new_labels = None
1636
+ else:
1637
+ new_labels = new_labels_padded
1638
+
1639
+ if _attention_mask is None:
1640
+ attention_mask = None
1641
+ else:
1642
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
1643
+
1644
+ if _position_ids is None:
1645
+ position_ids = None
1646
+
1647
+ return (
1648
+ None,
1649
+ position_ids,
1650
+ attention_mask,
1651
+ past_key_values,
1652
+ new_input_embeds,
1653
+ new_labels,
1654
+ vision_tower_aux_feature_list_final,
1655
+ vision_tower_aux_attention_masks_list_final,
1656
+ final_size,
1657
+ global_context_feature_final,
1658
+ )
1659
+
1660
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
1661
+ if model_args.mm_use_im_patch_token:
1662
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
1663
+ self.resize_token_embeddings(len(tokenizer))
1664
+
1665
+ if model_args.mm_use_im_start_end:
1666
+ num_new_tokens = tokenizer.add_tokens(
1667
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
1668
+ )
1669
+ self.resize_token_embeddings(len(tokenizer))
1670
+
1671
+ if num_new_tokens > 0:
1672
+ input_embeddings = self.get_input_embeddings().weight.data
1673
+ output_embeddings = self.get_output_embeddings().weight.data
1674
+
1675
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
1676
+ dim=0, keepdim=True
1677
+ )
1678
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
1679
+ dim=0, keepdim=True
1680
+ )
1681
+
1682
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
1683
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
1684
+
1685
+ if model_args.tune_mm_mlp_adapter:
1686
+ for p in self.get_input_embeddings().parameters():
1687
+ p.requires_grad = True
1688
+ for p in self.get_output_embeddings().parameters():
1689
+ p.requires_grad = False
1690
+
1691
+ if model_args.pretrain_mm_mlp_adapter:
1692
+ mm_projector_weights = torch.load(
1693
+ model_args.pretrain_mm_mlp_adapter, map_location="cpu"
1694
+ )
1695
+ embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
1696
+ assert num_new_tokens == 2
1697
+ if input_embeddings.shape == embed_tokens_weight.shape:
1698
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[
1699
+ -num_new_tokens:
1700
+ ]
1701
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
1702
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
1703
+ else:
1704
+ raise ValueError(
1705
+ f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
1706
+ )
1707
+ elif model_args.mm_use_im_patch_token:
1708
+ if model_args.tune_mm_mlp_adapter:
1709
+ for p in self.get_input_embeddings().parameters():
1710
+ p.requires_grad = False
1711
+ for p in self.get_output_embeddings().parameters():
1712
+ p.requires_grad = False
config.json ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "jadechoghari/LongVU_Qwen2_7B",
3
+ "architectures": [
4
+ "CambrianQwenForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "modeling.CambrianConfig",
8
+ "AutoModel": "modeling.CambrianLlamaForCausalLM",
9
+ "AutoModelForCausalLM": "modeling.CambrianLlamaForCausalLM"
10
+ },
11
+ "attention_bias": false,
12
+ "attention_dropout": 0.0,
13
+ "bos_token_id": 151643,
14
+ "connect_layer": 2,
15
+ "connector_depth": 3,
16
+ "connector_only": true,
17
+ "dino_threshold": 0.83,
18
+ "drop_threshold": 0.7,
19
+ "eos_token_id": 151645,
20
+ "frame_pos": false,
21
+ "freeze_mm_mlp_adapter": false,
22
+ "hidden_act": "silu",
23
+ "hidden_size": 3584,
24
+ "highres": true,
25
+ "highres_connect": false,
26
+ "image_aspect_ratio": "pad",
27
+ "image_position": 91,
28
+ "image_token_len": 144,
29
+ "initializer_range": 0.02,
30
+ "intermediate_size": 18944,
31
+ "is_st_sampler": false,
32
+ "lowres_token": 8,
33
+ "max_position_embeddings": 32768,
34
+ "max_window_layers": 28,
35
+ "mm_patch_merge_type": "flat",
36
+ "mm_projector_lr": null,
37
+ "mm_projector_type": "sva",
38
+ "mm_use_im_patch_token": false,
39
+ "mm_use_im_start_end": false,
40
+ "mm_vision_sampler_lr": null,
41
+ "mm_vision_select_feature": "patch",
42
+ "mm_vision_select_layer": -2,
43
+ "mm_vision_tower_aux_list": [
44
+ "siglip/CLIP-ViT-SO400M-14-384",
45
+ "facebook/dinov2-giant-res378"
46
+ ],
47
+ "mm_vision_tower_aux_token_len_list": [
48
+ 576,
49
+ 576
50
+ ],
51
+ "mm_vision_tower_lr": null,
52
+ "model_type": "cambrian_qwen",
53
+ "num_attention_heads": 28,
54
+ "num_hidden_layers": 28,
55
+ "num_key_value_heads": 4,
56
+ "num_of_vision_sampler_layers": 10,
57
+ "num_query_group": 1,
58
+ "pretraining_tp": 1,
59
+ "query_num_list": [
60
+ 144
61
+ ],
62
+ "rms_norm_eps": 1e-06,
63
+ "rope_scaling": null,
64
+ "rope_theta": 1000000.0,
65
+ "sliding_window": null,
66
+ "spmd_debug": null,
67
+ "spmd_fsdp_sharding": null,
68
+ "spmd_mesh": null,
69
+ "start_of_vision_sampler_layers": 0,
70
+ "stride_of_vision_sampler_layers": 3,
71
+ "tie_word_embeddings": false,
72
+ "tokenizer_model_max_length": 10000,
73
+ "tokenizer_padding_side": "right",
74
+ "torch_dtype": "float32",
75
+ "transformers_version": "4.44.2",
76
+ "tune_mm_mlp_adapter": false,
77
+ "unfreeze_mm_vision_tower": false,
78
+ "use_cache": false,
79
+ "use_mm_proj": true,
80
+ "use_pos_skipping": false,
81
+ "use_sliding_window": false,
82
+ "vision_hidden_size": 1024,
83
+ "vision_tower_aux_token_len_list": [
84
+ 576,
85
+ 576
86
+ ],
87
+ "vocab_size": 152064
88
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e846f373072ab8e42ee7963e21514d543696ee2859c30570bb1b05a88d94f3ca
3
+ size 15343381968
modeling.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.nn import CrossEntropyLoss
22
+
23
+ from transformers import AutoConfig, AutoModelForCausalLM
24
+ from transformers.cache_utils import Cache, DynamicCache
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPast,
29
+ CausalLMOutputWithPast,
30
+ )
31
+ from transformers.utils import logging
32
+
33
+ from .cambrian_arch import CambrianMetaForCausalLM, CambrianMetaModel
34
+
35
+ IS_XLA_AVAILABLE = False
36
+
37
+ from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2Model
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ class CambrianConfig(Qwen2Config):
43
+ model_type = "cambrian_qwen"
44
+
45
+ debug = "debug"
46
+
47
+
48
+ class CambrianQwenModel(CambrianMetaModel, Qwen2Model):
49
+ config_class = CambrianConfig
50
+
51
+ def __init__(self, config: Qwen2Config):
52
+ super(CambrianQwenModel, self).__init__(config)
53
+
54
+ def forward(
55
+ self,
56
+ # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
57
+ input_ids: torch.LongTensor = None,
58
+ attention_mask: Optional[torch.Tensor] = None,
59
+ position_ids: Optional[torch.LongTensor] = None,
60
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
61
+ inputs_embeds: Optional[torch.FloatTensor] = None,
62
+ use_cache: Optional[bool] = None,
63
+ output_attentions: Optional[bool] = None,
64
+ output_hidden_states: Optional[bool] = None,
65
+ return_dict: Optional[bool] = None,
66
+ cache_position: Optional[torch.LongTensor] = None,
67
+ vision_tower_aux_feature_list: Optional[List[torch.FloatTensor]] = None,
68
+ vision_tower_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
69
+ final_vision_feature_size: Optional[List[tuple]] = None,
70
+ global_context_feature: Optional[torch.Tensor] = None,
71
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
72
+ output_attentions = (
73
+ output_attentions
74
+ if output_attentions is not None
75
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `config`.
76
+ else self.config.output_attentions
77
+ )
78
+ output_hidden_states = (
79
+ output_hidden_states
80
+ if output_hidden_states is not None
81
+ else self.config.output_hidden_states
82
+ )
83
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
84
+
85
+ return_dict = (
86
+ return_dict if return_dict is not None else self.config.use_return_dict
87
+ )
88
+
89
+ if (input_ids is None) ^ (inputs_embeds is not None):
90
+ raise ValueError(
91
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
92
+ )
93
+
94
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `gradient_checkpointing`.
95
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `training`.
96
+ if self.gradient_checkpointing and self.training:
97
+ if use_cache:
98
+ logger.warning_once(
99
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
100
+ )
101
+ use_cache = False
102
+
103
+ use_legacy_cache = False
104
+ if use_cache and not isinstance(past_key_values, Cache):
105
+ use_legacy_cache = True
106
+ # pyre-fixme[6]: For 1st argument expected
107
+ # `Optional[Tuple[Tuple[FloatTensor]]]` but got
108
+ # `Optional[List[FloatTensor]]`.
109
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
110
+ logger.warning_once(
111
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
112
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
113
+ )
114
+
115
+ if inputs_embeds is None:
116
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `embed_tokens`.
117
+ inputs_embeds = self.embed_tokens(input_ids)
118
+
119
+ if cache_position is None:
120
+ past_seen_tokens = (
121
+ # pyre-fixme[16]: Item `List` of `Union[List[torch._C.FloatTensor],
122
+ # DynamicCache]` has no attribute `get_seq_length`.
123
+ past_key_values.get_seq_length() if past_key_values is not None else 0
124
+ )
125
+ cache_position = torch.arange(
126
+ past_seen_tokens,
127
+ past_seen_tokens + inputs_embeds.shape[1],
128
+ device=inputs_embeds.device,
129
+ )
130
+ if position_ids is None:
131
+ position_ids = cache_position.unsqueeze(0)
132
+
133
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `_update_causal_mask`.
134
+ causal_mask = self._update_causal_mask(
135
+ attention_mask,
136
+ inputs_embeds,
137
+ cache_position,
138
+ past_key_values,
139
+ output_attentions,
140
+ )
141
+
142
+ hidden_states = inputs_embeds
143
+
144
+ # decoder layers
145
+ all_hidden_states = () if output_hidden_states else None
146
+ all_self_attns = () if output_attentions else None
147
+ next_decoder_cache = None
148
+
149
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `layers`.
150
+ for i, decoder_layer in enumerate(self.layers):
151
+ if output_hidden_states:
152
+ all_hidden_states += (hidden_states,)
153
+
154
+ if self.gradient_checkpointing and self.training:
155
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute
156
+ # `_gradient_checkpointing_func`.
157
+ layer_outputs = self._gradient_checkpointing_func(
158
+ decoder_layer.__call__,
159
+ hidden_states,
160
+ causal_mask,
161
+ position_ids,
162
+ past_key_values,
163
+ output_attentions,
164
+ use_cache,
165
+ cache_position,
166
+ )
167
+ else:
168
+ layer_outputs = decoder_layer(
169
+ hidden_states,
170
+ attention_mask=causal_mask,
171
+ position_ids=position_ids,
172
+ past_key_value=past_key_values,
173
+ output_attentions=output_attentions,
174
+ use_cache=use_cache,
175
+ cache_position=cache_position,
176
+ )
177
+
178
+ hidden_states = layer_outputs[0]
179
+
180
+ if use_cache:
181
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
182
+
183
+ if output_attentions:
184
+ all_self_attns += (layer_outputs[1],)
185
+
186
+ # pyre-fixme[16]: `CambrianQwenModel` has no attribute `norm`.
187
+ hidden_states = self.norm(hidden_states)
188
+
189
+ # add hidden states from the last decoder layer
190
+ if output_hidden_states:
191
+ all_hidden_states += (hidden_states,)
192
+
193
+ next_cache = None
194
+ if use_cache:
195
+ next_cache = (
196
+ next_decoder_cache.to_legacy_cache()
197
+ if use_legacy_cache
198
+ else next_decoder_cache
199
+ )
200
+
201
+ if not return_dict:
202
+ return tuple(
203
+ v
204
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
205
+ if v is not None
206
+ )
207
+ return BaseModelOutputWithPast(
208
+ last_hidden_state=hidden_states,
209
+ past_key_values=next_cache,
210
+ hidden_states=all_hidden_states,
211
+ attentions=all_self_attns,
212
+ )
213
+
214
+
215
+ class CambrianQwenForCausalLM(Qwen2ForCausalLM, CambrianMetaForCausalLM):
216
+ config_class = CambrianConfig
217
+
218
+ def __init__(self, config):
219
+ # super(Qwen2ForCausalLM, self).__init__(config)
220
+ Qwen2ForCausalLM.__init__(self, config)
221
+ config.model_type = "cambrian_qwen"
222
+ config.rope_scaling = None
223
+
224
+ self.model = CambrianQwenModel(config)
225
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
226
+ # Initialize weights and apply final processing
227
+ self.post_init()
228
+
229
+ def get_model(self):
230
+ return self.model
231
+
232
+ def forward(
233
+ self,
234
+ # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
235
+ input_ids: torch.LongTensor = None,
236
+ attention_mask: Optional[torch.Tensor] = None,
237
+ position_ids: Optional[torch.LongTensor] = None,
238
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
239
+ inputs_embeds: Optional[torch.FloatTensor] = None,
240
+ labels: Optional[torch.LongTensor] = None,
241
+ use_cache: Optional[bool] = None,
242
+ output_attentions: Optional[bool] = None,
243
+ output_hidden_states: Optional[bool] = None,
244
+ images: Optional[torch.FloatTensor] = None,
245
+ image_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
246
+ image_sizes: Optional[List[List[int]]] = None,
247
+ return_dict: Optional[bool] = None,
248
+ modalities: Optional[List[str]] = ["image"],
249
+ dpo_forward: Optional[bool] = False,
250
+ cache_position=None,
251
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
252
+
253
+ input_image_features = None
254
+ highres_image_features = None
255
+ frame_split_sizes = None
256
+
257
+ if inputs_embeds is None:
258
+ (
259
+ input_ids,
260
+ position_ids,
261
+ attention_mask,
262
+ past_key_values,
263
+ inputs_embeds,
264
+ labels,
265
+ vision_tower_aux_feature_list,
266
+ vision_tower_aux_attention_masks_list,
267
+ final_vision_feature_size,
268
+ global_context_feature,
269
+ ) = self.prepare_inputs_labels_for_multimodal(
270
+ input_ids,
271
+ position_ids,
272
+ attention_mask,
273
+ past_key_values,
274
+ labels,
275
+ images,
276
+ image_aux_attention_masks_list,
277
+ image_sizes,
278
+ )
279
+
280
+ if dpo_forward:
281
+ # pyre-fixme[29]: `CambrianQwenModel` is not a function.
282
+ outputs = self.model(
283
+ input_ids=input_ids,
284
+ attention_mask=attention_mask,
285
+ position_ids=position_ids,
286
+ past_key_values=past_key_values,
287
+ inputs_embeds=inputs_embeds,
288
+ use_cache=use_cache,
289
+ output_attentions=output_attentions,
290
+ output_hidden_states=output_hidden_states,
291
+ return_dict=return_dict,
292
+ )
293
+
294
+ hidden_states = outputs[0]
295
+ logits = self.lm_head(hidden_states)
296
+ return logits, labels
297
+
298
+ else:
299
+ if hasattr(self, "vision_tower_aux_feature_list"):
300
+ # pyre-fixme[29]: `CambrianQwenModel` is not a function.
301
+ outputs = self.model(
302
+ input_ids=input_ids,
303
+ attention_mask=attention_mask,
304
+ position_ids=position_ids,
305
+ past_key_values=past_key_values,
306
+ inputs_embeds=inputs_embeds,
307
+ use_cache=use_cache,
308
+ output_attentions=output_attentions,
309
+ output_hidden_states=output_hidden_states,
310
+ return_dict=return_dict,
311
+ vision_tower_aux_feature_list=(
312
+ # pyre-fixme[61]: `vision_tower_aux_feature_list` is
313
+ # undefined, or not always defined.
314
+ vision_tower_aux_feature_list
315
+ if inputs_embeds is None
316
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
317
+ # `vision_tower_aux_feature_list`.
318
+ else self.vision_tower_aux_feature_list
319
+ ),
320
+ vision_tower_aux_attention_masks_list=(
321
+ # pyre-fixme[61]: `vision_tower_aux_attention_masks_list` is
322
+ # undefined, or not always defined.
323
+ vision_tower_aux_attention_masks_list
324
+ if inputs_embeds is None
325
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
326
+ # `vision_tower_aux_attention_masks_list`.
327
+ else self.vision_tower_aux_attention_masks_list
328
+ ),
329
+ final_vision_feature_size=(
330
+ # pyre-fixme[61]: `final_vision_feature_size` is undefined,
331
+ # or not always defined.
332
+ final_vision_feature_size
333
+ if inputs_embeds is None
334
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
335
+ # `final_vision_feature_size`.
336
+ else self.final_vision_feature_size
337
+ ),
338
+ global_context_feature=(
339
+ # pyre-fixme[61]: `global_context_feature` is undefined, or
340
+ # not always defined.
341
+ global_context_feature
342
+ if inputs_embeds is None
343
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
344
+ # `global_context_feature`.
345
+ else self.global_context_feature
346
+ ),
347
+ )
348
+ else:
349
+ # pyre-fixme[29]: `CambrianQwenModel` is not a function.
350
+ outputs = self.model(
351
+ input_ids=input_ids,
352
+ attention_mask=attention_mask,
353
+ position_ids=position_ids,
354
+ past_key_values=past_key_values,
355
+ inputs_embeds=inputs_embeds,
356
+ use_cache=use_cache,
357
+ output_attentions=output_attentions,
358
+ output_hidden_states=output_hidden_states,
359
+ return_dict=return_dict,
360
+ # final_vision_feature_size=final_vision_feature_size,
361
+ )
362
+
363
+ hidden_states = outputs[0]
364
+ logits = self.lm_head(hidden_states)
365
+ logits = logits.float()
366
+
367
+ loss = None
368
+ if labels is not None:
369
+ # Shift so that tokens < n predict n
370
+ shift_logits = logits[..., :-1, :].contiguous()
371
+ shift_labels = labels[..., 1:].contiguous()
372
+ # Flatten the tokens
373
+ loss_fct = CrossEntropyLoss()
374
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute `config`.
375
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
376
+ shift_labels = shift_labels.view(-1)
377
+ # Enable model parallelism
378
+ shift_labels = shift_labels.to(shift_logits.device)
379
+ loss = loss_fct(shift_logits, shift_labels)
380
+
381
+ if not return_dict:
382
+ output = (logits,) + outputs[1:]
383
+ return (loss,) + output if loss is not None else output
384
+
385
+ return CausalLMOutputWithPast(
386
+ loss=loss,
387
+ logits=logits,
388
+ past_key_values=outputs.past_key_values,
389
+ hidden_states=outputs.hidden_states,
390
+ attentions=outputs.attentions,
391
+ )
392
+
393
+ @torch.no_grad()
394
+ def generate(
395
+ self,
396
+ inputs: Optional[torch.Tensor] = None,
397
+ images: Optional[torch.Tensor] = None,
398
+ image_sizes: Optional[torch.Tensor] = None,
399
+ **kwargs,
400
+ ) -> Union[GenerateOutput, torch.LongTensor]:
401
+ position_ids = kwargs.pop("position_ids", None)
402
+ attention_mask = kwargs.pop("attention_mask", None)
403
+ if "inputs_embeds" in kwargs:
404
+ raise NotImplementedError("`inputs_embeds` is not supported")
405
+
406
+ if images is not None:
407
+ (
408
+ inputs,
409
+ position_ids,
410
+ attention_mask,
411
+ _,
412
+ inputs_embeds,
413
+ _,
414
+ vision_tower_aux_feature_list,
415
+ vision_tower_aux_attention_masks_list,
416
+ final_vision_feature_size,
417
+ global_context_feature,
418
+ ) = self.prepare_inputs_labels_for_multimodal(
419
+ inputs,
420
+ position_ids,
421
+ attention_mask,
422
+ None,
423
+ None,
424
+ images,
425
+ image_sizes=image_sizes,
426
+ )
427
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
428
+ # `vision_tower_aux_feature_list`.
429
+ self.vision_tower_aux_feature_list = vision_tower_aux_feature_list
430
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
431
+ # `vision_tower_aux_attention_masks_list`.
432
+ self.vision_tower_aux_attention_masks_list = (
433
+ vision_tower_aux_attention_masks_list
434
+ )
435
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
436
+ # `final_vision_feature_size`.
437
+ self.final_vision_feature_size = final_vision_feature_size
438
+ # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
439
+ # `global_context_feature`.
440
+ self.global_context_feature = global_context_feature
441
+ else:
442
+ inputs_embeds = self.get_model().embed_tokens(inputs)
443
+
444
+ # pyre-fixme[16]: `Qwen2ForCausalLM` has no attribute `generate`.
445
+ return super().generate(
446
+ position_ids=position_ids,
447
+ attention_mask=attention_mask,
448
+ inputs_embeds=inputs_embeds,
449
+ **kwargs,
450
+ )
451
+
452
+ def prepare_inputs_for_generation(
453
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
454
+ ):
455
+ images = kwargs.pop("images", None)
456
+ image_sizes = kwargs.pop("image_sizes", None)
457
+ inputs = super().prepare_inputs_for_generation(
458
+ input_ids,
459
+ past_key_values=past_key_values,
460
+ inputs_embeds=inputs_embeds,
461
+ **kwargs,
462
+ )
463
+ if images is not None:
464
+ inputs["images"] = images
465
+ if image_sizes is not None:
466
+ inputs["image_sizes"] = image_sizes
467
+ return inputs
468
+
469
+
470
+ AutoConfig.register("cambrian_qwen", CambrianConfig)
471
+ AutoModelForCausalLM.register(CambrianConfig, CambrianQwenForCausalLM)
multimodal_encoder_builder.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pyre-unsafe
2
+ import copy
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from transformers import AutoImageProcessor, Dinov2Config, Dinov2Model, SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
6
+ from abc import ABC, abstractmethod
7
+ import torch.nn as nn
8
+
9
+
10
+ class ProcessorWrapper:
11
+ def __init__(
12
+ self,
13
+ transform,
14
+ height=378,
15
+ width=378,
16
+ image_mean=[0.48145466, 0.4578275, 0.40821073],
17
+ ):
18
+ self._crop_size = {
19
+ "height": height,
20
+ "width": width,
21
+ }
22
+ self._transforms = transform
23
+ # print(transform)
24
+ self.image_mean = image_mean
25
+
26
+ @property
27
+ def crop_size(self):
28
+ return self._crop_size
29
+
30
+ def preprocess(self, image, return_tensors="pt"):
31
+ # Ensure image is a PIL Image
32
+ output = {}
33
+ output["pixel_values"] = [self._transforms(image)]
34
+ return output
35
+
36
+
37
+ class BaseVisionTower(nn.Module):
38
+ def __init__(self, vision_tower_name, args, delay_load=False):
39
+ super().__init__()
40
+
41
+ self.is_loaded = False
42
+ self.args = args
43
+
44
+ self.vision_tower_name = vision_tower_name
45
+ self.select_layer = args.mm_vision_select_layer
46
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
47
+ self.unfreeze_mm_vision_tower = getattr(args, "unfreeze_mm_vision_tower", False)
48
+ self.delay_load = delay_load
49
+
50
+ @abstractmethod
51
+ def load_model(self, device_map=None):
52
+ raise NotImplementedError("Subclasses must implement load_model")
53
+
54
+ @abstractmethod
55
+ def _forward(self, images):
56
+ raise NotImplementedError("Subclasses must implement forward")
57
+
58
+ def forward(self, images):
59
+ if type(images) is list:
60
+ image_features = [self._forward(image.unsqueeze(0)) for image in images]
61
+ else:
62
+ image_features = self._forward(images)
63
+
64
+ return image_features
65
+
66
+ @property
67
+ def dummy_feature(self):
68
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
69
+
70
+ @property
71
+ def dtype(self):
72
+ # Dynamically infer the dtype from the first parameter, if not explicitly specified
73
+ if hasattr(self.vision_tower, "dtype"):
74
+ return self.vision_tower.dtype
75
+ else:
76
+ params = list(self.vision_tower.parameters())
77
+ return (
78
+ params[0].dtype if len(params) > 0 else torch.float32
79
+ ) # Default to torch.float32 if no parameters
80
+
81
+ @property
82
+ def device(self):
83
+ # Dynamically infer the device from the first parameter, if not explicitly specified
84
+ if hasattr(self.vision_tower, "device"):
85
+ return self.vision_tower.device
86
+ else:
87
+ params = list(self.vision_tower.parameters())
88
+ return (
89
+ params[0].device if len(params) > 0 else torch.device("cpu")
90
+ ) # Default to CPU if no parameters
91
+
92
+ @property
93
+ def config(self):
94
+ if self.is_loaded:
95
+ return self.vision_tower.config
96
+ else:
97
+ return self.cfg_only
98
+
99
+ @property
100
+ def hidden_size(self):
101
+ try:
102
+ return self.config.hidden_size
103
+ except:
104
+ return self._hidden_size
105
+
106
+ @property
107
+ def image_size(self): # resolution
108
+ # return self.config.image_size
109
+ try:
110
+ return self.config.image_size
111
+ except:
112
+ return self._image_size
113
+
114
+ @property
115
+ def patch_size(self):
116
+ # return self.config.patch_size
117
+ try:
118
+ return self.config.patch_size
119
+ except:
120
+ return self._patch_size
121
+
122
+ @property
123
+ def num_patches_per_side(self):
124
+ if self._interp_size is not None:
125
+ return int(self._interp_size**0.5)
126
+ try:
127
+ return self.image_size // self.patch_size
128
+ except:
129
+ return self._num_patches_per_side
130
+
131
+ @property
132
+ def num_patches(self):
133
+ if self._interp_size is not None:
134
+ return self._interp_size
135
+ try:
136
+ return self.num_patches_per_side**2
137
+ except:
138
+ return self._num_patches
139
+
140
+
141
+ class DinoVisionTower(BaseVisionTower):
142
+ def __init__(self, vision_tower, args, delay_load=False):
143
+ super(DinoVisionTower, self).__init__(vision_tower, args, delay_load)
144
+
145
+ model_path = "facebook/dinov2-giant"
146
+ base_model_name, res, interp = model_path, 378, 576
147
+ self._vision_tower_name = vision_tower
148
+ self.vision_tower_name = base_model_name
149
+ self._image_size = res
150
+ self._interp_size = interp
151
+ self._patch_size = 14 # default patch size
152
+
153
+ if not self.delay_load:
154
+ self.load_model()
155
+ else:
156
+ self.cfg_only = Dinov2Config.from_pretrained(self.vision_tower_name)
157
+
158
+ def load_model(self, device_map=None):
159
+
160
+ self.vision_tower = Dinov2Model.from_pretrained(self.vision_tower_name)
161
+ """ValueError: Dinov2Model does not support `device_map='auto'`. To implement support, the model class needs to implement the `_no_split_modules` attribute."""
162
+ self.vision_tower._no_split_modules = ["Dinov2SwiGLUFFN"]
163
+
164
+ _image_size = self.vision_tower.config.image_size
165
+ if self._image_size is None:
166
+ self._image_size = _image_size
167
+
168
+ # increase shortest edge to prevent edge case crops
169
+ default_shortest_ratio = 8 / 7 # 224/256
170
+ # shortest_edge = int(default_shortest_ratio * self._image_size)
171
+ shortest_edge = self._image_size
172
+
173
+ processor = AutoImageProcessor.from_pretrained(
174
+ self.vision_tower_name,
175
+ crop_size=dict(height=self._image_size, width=self._image_size),
176
+ size=dict(shortest_edge=shortest_edge),
177
+ )
178
+ self.image_processor = processor
179
+
180
+ # Assign the output channels of the projection convolution as the hidden size
181
+ self._hidden_size = (
182
+ self.vision_tower.embeddings.patch_embeddings.projection.out_channels
183
+ )
184
+ # Assign the first value of the stride of the projection convolution as the patch size
185
+ self._patch_size = (
186
+ self.vision_tower.embeddings.patch_embeddings.projection.stride[0]
187
+ )
188
+
189
+ # print(self._hidden_size, self._patch_size)
190
+
191
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
192
+ self.is_loaded = True
193
+
194
+ @property
195
+ def image_size(self):
196
+ return self._image_size
197
+
198
+ def feature_select(self, outputs):
199
+ sequence_output = outputs[
200
+ "last_hidden_state"
201
+ ] # batch_size, sequence_length, hidden_size
202
+
203
+ if self.select_feature == "cls_patch":
204
+ image_features = sequence_output
205
+ elif self.select_feature == "patch":
206
+ image_features = sequence_output[:, 1:]
207
+ elif self.select_feature == "cls":
208
+ image_features = sequence_output[:, 0]
209
+ else:
210
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
211
+ return image_features
212
+
213
+ def interpolate(self, image_features):
214
+ if self._interp_size is None:
215
+ return image_features
216
+
217
+ b, num_tokens, dim = image_features.shape
218
+
219
+ if num_tokens != self.num_patches:
220
+ target_h = target_w = int(self._interp_size**0.5)
221
+ h = w = int(num_tokens**0.5)
222
+
223
+ image_features = image_features.view(b, h, w, dim)
224
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
225
+
226
+ image_features = F.interpolate(
227
+ image_features.to(torch.float32),
228
+ size=(target_h, target_w),
229
+ mode="bilinear",
230
+ align_corners=False,
231
+ ).to(image_features.dtype)
232
+
233
+ # Permute the dimensions back to (b, target_h, target_w, dim)
234
+ image_features = image_features.permute(0, 2, 3, 1).contiguous()
235
+
236
+ # Flatten the spatial dimensions (target_h, target_w) into a single dimension
237
+ image_features = image_features.flatten(1, 2)
238
+
239
+ return image_features
240
+
241
+ def _forward(self, images):
242
+ # logger.warning(f"images shape: {images.shape}")
243
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
244
+ image_forward_outs = self.vision_tower.forward(
245
+ images.to(device=self.device, dtype=self.dtype)
246
+ )
247
+ # logger.warning(f"image_forward_outs shape: {image_forward_outs['last_hidden_state'].shape}")
248
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
249
+ # logger.warning(f"image_features shape: {image_features.shape}")
250
+ interp_features = self.interpolate(image_features)
251
+ # logger.warning(f"interp_features shape: {interp_features.shape}")
252
+ return interp_features
253
+
254
+ @property
255
+ def num_patches_per_side(self):
256
+ return int(self.num_patches**0.5)
257
+
258
+ @property
259
+ def num_patches(self):
260
+ if self._interp_size is None:
261
+ return (self._image_size // self._patch_size) ** 2
262
+ else:
263
+ return self._interp_size
264
+
265
+
266
+ # from .siglip_encoder import SiglipVisionTower
267
+ class SiglipVisionTower(BaseVisionTower):
268
+ def __init__(self, vision_tower_name, args, delay_load=False):
269
+ super(SiglipVisionTower, self).__init__(vision_tower_name, args, delay_load)
270
+
271
+ model_path = "google/siglip-so400m-patch14-384"
272
+ base_model_name, res, interp = model_path, 384, 576
273
+ self.vision_tower_name = base_model_name
274
+ self._image_size = res if res is not None else 512
275
+ self._interp_size = interp
276
+ if not self.delay_load:
277
+ self.load_model()
278
+ elif self.unfreeze_mm_vision_tower:
279
+ self.load_model()
280
+ else:
281
+ self._hidden_size = 1152
282
+
283
+ def load_model(self, device_map=None):
284
+ self.vision_model = "siglip"
285
+ # clip_model, processor = create_model_from_pretrained(self.vision_tower_name)
286
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
287
+
288
+ # self.vision_tower = clip_model.visual.trunk
289
+ self.vision_tower.output_tokens = True
290
+
291
+ self._hidden_size = self.vision_tower.config.hidden_size
292
+ self._image_size = self.vision_tower.config.image_size
293
+ self._patch_size = self.vision_tower.config.patch_size
294
+ self.image_processor = SiglipImageProcessor.from_pretrained(
295
+ self.vision_tower_name
296
+ )
297
+
298
+ self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
299
+ self.is_loaded = True
300
+
301
+ def interpolate(self, image_features):
302
+ if self._interp_size is None:
303
+ return image_features
304
+
305
+ b, num_tokens, dim = image_features.shape
306
+
307
+ if num_tokens != self.num_patches:
308
+ target_h = target_w = int(self._interp_size**0.5)
309
+ h = w = int(num_tokens**0.5)
310
+
311
+ image_features = image_features.view(b, h, w, dim)
312
+ image_features = image_features.permute(0, 3, 1, 2).contiguous()
313
+
314
+ image_features = F.interpolate(
315
+ image_features.to(torch.float32),
316
+ size=(target_h, target_w),
317
+ mode="bilinear",
318
+ align_corners=False,
319
+ ).to(image_features.dtype)
320
+
321
+ # Permute the dimensions back to (b, target_h, target_w, dim)
322
+ image_features = image_features.permute(0, 2, 3, 1).contiguous()
323
+
324
+ # Flatten the spatial dimensions (target_h, target_w) into a single dimension
325
+ image_features = image_features.flatten(1, 2)
326
+
327
+ return image_features
328
+
329
+ def _forward(self, images, interpolate_token=576):
330
+ with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
331
+ image_features = self.vision_tower.forward(
332
+ images.to(device=self.device, dtype=self.dtype),
333
+ output_hidden_states=True,
334
+ ).hidden_states[-1]
335
+ interp_features = self.interpolate(image_features)
336
+ return interp_features
337
+
338
+
339
+ def build_vision_tower_aux_list(vision_tower_cfg, **kwargs):
340
+ vision_tower_aux_name_list = getattr(
341
+ vision_tower_cfg,
342
+ "mm_vision_tower_aux_list",
343
+ getattr(vision_tower_cfg, "vision_tower_aux_list", None),
344
+ )
345
+ vision_tower_aux_token_len_list = getattr(
346
+ vision_tower_cfg,
347
+ "mm_vision_tower_aux_token_len_list",
348
+ getattr(vision_tower_cfg, "vision_tower_aux_token_len_list", None),
349
+ )
350
+ vision_tower_aux_list = []
351
+ for vision_tower_aux_name, vision_tower_aux_token_len in zip(
352
+ vision_tower_aux_name_list, vision_tower_aux_token_len_list
353
+ ):
354
+ config = copy.deepcopy(vision_tower_cfg)
355
+ vision_tower_aux_name += "-interp{}".format(vision_tower_aux_token_len)
356
+ if "siglip" in vision_tower_aux_name.lower():
357
+ vision_tower_aux_list.append(
358
+ SiglipVisionTower(vision_tower_aux_name, args=config, **kwargs)
359
+ )
360
+
361
+ # SSL-based Vision Towers
362
+ elif "dinov2" in vision_tower_aux_name.lower():
363
+ vision_tower_aux_list.append(
364
+ DinoVisionTower(vision_tower_aux_name, args=config, **kwargs)
365
+ )
366
+ else:
367
+ raise ValueError(f"Unknown vision tower: {vision_tower_aux_name}")
368
+ return vision_tower_aux_list
multimodal_projector_builder.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pyre-unsafe
2
+ import re
3
+
4
+ import torch.nn as nn
5
+
6
+
7
+ class IdentityMap(nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ def forward(self, x, *args, **kwargs):
12
+ return x
13
+
14
+ @property
15
+ def config(self):
16
+ return {"mm_projector_type": "identity"}
17
+
18
+
19
+ class SimpleResBlock(nn.Module):
20
+ def __init__(self, channels):
21
+ super().__init__()
22
+ self.pre_norm = nn.LayerNorm(channels)
23
+
24
+ self.proj = nn.Sequential(
25
+ nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
26
+ )
27
+
28
+ def forward(self, x):
29
+ x = self.pre_norm(x)
30
+ return x + self.proj(x)
31
+
32
+
33
+ def build_vision_projector(config, delay_load=False, **kwargs):
34
+ projector_type = getattr(config, "mm_projector_type", "linear")
35
+ config.mm_hidden_size = 256
36
+
37
+ if projector_type == "linear":
38
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
39
+
40
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
41
+ if mlp_gelu_match:
42
+ mlp_depth = int(mlp_gelu_match.group(1))
43
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
44
+ for _ in range(1, mlp_depth):
45
+ modules.append(nn.GELU())
46
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
47
+ return nn.Sequential(*modules)
48
+
49
+ if projector_type == "identity":
50
+ return IdentityMap()
51
+
52
+ raise ValueError(f"Unknown projector type: {projector_type}")
special_tokens_map.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>"
5
+ ],
6
+ "eos_token": {
7
+ "content": "<|im_end|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false
12
+ },
13
+ "pad_token": {
14
+ "content": "<|endoftext|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ }
20
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "151643": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "151644": {
13
+ "content": "<|im_start|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "151645": {
21
+ "content": "<|im_end|>",
22
+ "lstrip": false,
23
+ "normalized": false,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ },
28
+ "151646": {
29
+ "content": "<image>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false,
34
+ "special": true
35
+ }
36
+ },
37
+ "additional_special_tokens": [
38
+ "<|im_start|>",
39
+ "<|im_end|>"
40
+ ],
41
+ "bos_token": null,
42
+ "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
43
+ "clean_up_tokenization_spaces": false,
44
+ "eos_token": "<|im_end|>",
45
+ "errors": "replace",
46
+ "model_max_length": 32768,
47
+ "pad_token": "<|endoftext|>",
48
+ "padding_side": "right",
49
+ "processor_class": "LlavaProcessor",
50
+ "split_special_tokens": false,
51
+ "tokenizer_class": "Qwen2Tokenizer",
52
+ "unk_token": null
53
+ }
vision_sampler.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+
8
+
9
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
10
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
11
+ """
12
+ grid_size: int of the grid height and width
13
+ return:
14
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
15
+ """
16
+ grid_h = np.arange(grid_size, dtype=np.float32)
17
+ grid_w = np.arange(grid_size, dtype=np.float32)
18
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
19
+ grid = np.stack(grid, axis=0)
20
+
21
+ grid = grid.reshape([2, 1, grid_size, grid_size])
22
+
23
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
24
+ if cls_token:
25
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
26
+ return pos_embed
27
+
28
+
29
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
30
+ assert embed_dim % 2 == 0
31
+
32
+ # use half of dimensions to encode grid_h
33
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
34
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
35
+
36
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
37
+ return emb
38
+
39
+
40
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
41
+ """
42
+ embed_dim: output dimension for each position
43
+ pos: a list of positions to be encoded: size (M,)
44
+ out: (M, D)
45
+ """
46
+ assert embed_dim % 2 == 0
47
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
48
+ omega /= embed_dim / 2.0
49
+ omega = 1.0 / 10000**omega # (D/2,)
50
+
51
+ pos = pos.reshape(-1) # (M,)
52
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
53
+
54
+ emb_sin = np.sin(out) # (M, D/2)
55
+ emb_cos = np.cos(out) # (M, D/2)
56
+
57
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
58
+ return emb
59
+
60
+
61
+ class CrossAttention(nn.Module):
62
+
63
+ def __init__(self, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False):
64
+ super().__init__()
65
+ self.hidden_dim = hidden_dim
66
+ self.num_heads = num_heads
67
+ self.head_dim = self.hidden_dim // self.num_heads
68
+
69
+ if (self.head_dim * self.num_heads) != self.hidden_dim:
70
+ raise ValueError(
71
+ f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
72
+ f" and `num_heads`: {self.num_heads})."
73
+ )
74
+
75
+ self.q_proj = nn.Sequential(
76
+ nn.LayerNorm(q_dim),
77
+ nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
78
+ )
79
+ self.k_proj = nn.Sequential(
80
+ nn.LayerNorm(kv_dim),
81
+ nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
82
+ )
83
+ self.v_proj = nn.Sequential(
84
+ nn.LayerNorm(kv_dim),
85
+ nn.Linear(kv_dim, self.num_heads * self.head_dim, bias=attention_bias),
86
+ )
87
+ self.o_proj = nn.Linear(
88
+ self.num_heads * self.head_dim, q_dim, bias=attention_bias
89
+ )
90
+
91
+ def forward(self, vision_latents, queries, attention_mask):
92
+
93
+ bsz, q_len, _ = queries.size()
94
+ bsz, v_len, _ = vision_latents.size()
95
+
96
+ query_states = self.q_proj(queries)
97
+ key_states = self.k_proj(vision_latents)
98
+ value_states = self.v_proj(vision_latents)
99
+
100
+ query_states = query_states.view(
101
+ bsz, q_len, self.num_heads, self.head_dim
102
+ ).transpose(1, 2)
103
+ key_states = key_states.view(
104
+ bsz, v_len, self.num_heads, self.head_dim
105
+ ).transpose(1, 2)
106
+ value_states = value_states.view(
107
+ bsz, v_len, self.num_heads, self.head_dim
108
+ ).transpose(1, 2)
109
+
110
+ if attention_mask is not None:
111
+ if attention_mask.size() != (bsz, 1, q_len, v_len):
112
+ raise ValueError(
113
+ f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
114
+ )
115
+
116
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
117
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
118
+ if query_states.device.type == "cuda" and attention_mask is not None:
119
+ query_states = query_states.contiguous()
120
+ key_states = key_states.contiguous()
121
+ value_states = value_states.contiguous()
122
+
123
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
124
+ query_states,
125
+ key_states,
126
+ value_states,
127
+ attn_mask=attention_mask,
128
+ )
129
+
130
+ attn_output = attn_output.transpose(1, 2).contiguous()
131
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)
132
+
133
+ attn_output = self.o_proj(attn_output)
134
+
135
+ return attn_output
136
+
137
+
138
+ class AggregationBlock(nn.Module):
139
+ def __init__(
140
+ self, attention, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False
141
+ ):
142
+ super().__init__()
143
+ self.hidden_dim = hidden_dim
144
+ self.num_heads = num_heads
145
+ self.head_dim = self.hidden_dim // self.num_heads
146
+
147
+ if (self.head_dim * self.num_heads) != self.hidden_dim:
148
+ raise ValueError(
149
+ f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
150
+ f" and `num_heads`: {self.num_heads})."
151
+ )
152
+
153
+ self.attention = attention
154
+ if attention:
155
+ self.attention_layer = CrossAttention(
156
+ q_dim, kv_dim, hidden_dim, num_heads, attention_bias
157
+ )
158
+ else:
159
+ self.attention_layer = MLP(kv_dim, q_dim, q_dim)
160
+
161
+ def forward(self, vision_latents, queries, attention_mask):
162
+ if self.attention:
163
+ queries = self.attention_layer(vision_latents, queries, attention_mask)
164
+ else:
165
+ queries = self.attention_layer(vision_latents)
166
+
167
+ return queries
168
+
169
+
170
+ class MultiKVCrossAttention(nn.Module):
171
+
172
+ def __init__(self, q_dim, kv_dim_list, hidden_dim, num_heads, attention_bias=False):
173
+ super().__init__()
174
+
175
+ self.hidden_dim = hidden_dim
176
+ self.num_heads = num_heads
177
+ self.head_dim = self.hidden_dim // self.num_heads
178
+
179
+ if (self.head_dim * self.num_heads) != self.hidden_dim:
180
+ raise ValueError(
181
+ f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
182
+ f" and `num_heads`: {self.num_heads})."
183
+ )
184
+
185
+ self.q_proj = nn.Sequential(
186
+ nn.LayerNorm(q_dim),
187
+ nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
188
+ )
189
+ self.num_of_kvs = len(kv_dim_list)
190
+ for i, kv_dim in enumerate(kv_dim_list):
191
+ setattr(
192
+ self,
193
+ "k_proj_{}".format(i),
194
+ nn.Sequential(
195
+ nn.LayerNorm(kv_dim),
196
+ nn.Linear(
197
+ kv_dim, self.num_heads * self.head_dim, bias=attention_bias
198
+ ),
199
+ ),
200
+ )
201
+ setattr(
202
+ self,
203
+ "v_proj_{}".format(i),
204
+ nn.Sequential(
205
+ nn.LayerNorm(kv_dim),
206
+ nn.Linear(
207
+ kv_dim, self.num_heads * self.head_dim, bias=attention_bias
208
+ ),
209
+ ),
210
+ )
211
+ self.o_proj = nn.Linear(
212
+ self.num_heads * self.head_dim, q_dim, bias=attention_bias
213
+ )
214
+
215
+ def forward(
216
+ self,
217
+ queries,
218
+ *vision_latents_attention_mask_list,
219
+ ):
220
+
221
+ vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
222
+ attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
223
+
224
+ bsz, q_len, _ = queries.size()
225
+
226
+ query_states = self.q_proj(queries)
227
+ key_states = torch.cat(
228
+ [
229
+ getattr(self, "k_proj_{}".format(i))(vision_latents_list[i])
230
+ for i in range(self.num_of_kvs)
231
+ ],
232
+ dim=1,
233
+ )
234
+ value_states = torch.cat(
235
+ [
236
+ getattr(self, "v_proj_{}".format(i))(vision_latents_list[i])
237
+ for i in range(self.num_of_kvs)
238
+ ],
239
+ dim=1,
240
+ )
241
+
242
+ v_len = key_states.shape[1]
243
+
244
+ query_states = query_states.view(
245
+ bsz, q_len, self.num_heads, self.head_dim
246
+ ).transpose(1, 2)
247
+ key_states = key_states.view(
248
+ bsz, v_len, self.num_heads, self.head_dim
249
+ ).transpose(1, 2)
250
+ value_states = value_states.view(
251
+ bsz, v_len, self.num_heads, self.head_dim
252
+ ).transpose(1, 2)
253
+
254
+ # if kv_weight is not None:
255
+ # kv_weight = kv_weight.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
256
+
257
+ attention_mask = torch.cat(attention_mask_list, dim=-1)
258
+
259
+ if attention_mask is not None:
260
+ if attention_mask.size() != (bsz, 1, q_len, v_len):
261
+ raise ValueError(
262
+ f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
263
+ )
264
+
265
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
266
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
267
+ if query_states.device.type == "cuda" and attention_mask is not None:
268
+ query_states = query_states.contiguous()
269
+ key_states = key_states.contiguous()
270
+ value_states = value_states.contiguous()
271
+
272
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
273
+ query_states,
274
+ key_states,
275
+ value_states,
276
+ attn_mask=attention_mask,
277
+ )
278
+ # attn_output = spda(
279
+ # query_states,
280
+ # key_states,
281
+ # value_states,
282
+ # attn_mask=attention_mask,
283
+ # additional_score=kv_weight
284
+ # )
285
+
286
+ attn_output = attn_output.transpose(1, 2).contiguous()
287
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)
288
+
289
+ attn_output = self.o_proj(attn_output)
290
+
291
+ return attn_output
292
+
293
+
294
+ class MLP(nn.Module):
295
+ def __init__(self, d_in, d_hidden, d_out):
296
+ super().__init__()
297
+ self.linear_1 = nn.Linear(d_in, d_hidden, bias=False)
298
+ self.act = nn.GELU()
299
+ self.linear_2 = nn.Linear(d_hidden, d_out, bias=False)
300
+
301
+ def forward(self, x):
302
+ return self.linear_2(self.act(self.linear_1(x)))
303
+
304
+
305
+ class VisionCrossAttentionLayer(nn.Module):
306
+ def __init__(
307
+ self,
308
+ q_dim,
309
+ context_dim,
310
+ kv_dim_list,
311
+ kv_size_list,
312
+ hidden_dim=1024,
313
+ layer_idx=0,
314
+ ):
315
+ super().__init__()
316
+ num_heads = 16
317
+ self.num_of_kvs = len(kv_dim_list)
318
+
319
+ self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
320
+ self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
321
+ # if self.num_of_kvs > 1:
322
+ # self.weight_mlp = MLP(q_dim+hidden_dim, hidden_dim, self.num_of_kvs)
323
+ # self.tower_weight = nn.Parameter(torch.zeros((self.num_of_kvs)))
324
+ self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)
325
+
326
+ self.norm = nn.LayerNorm(hidden_dim)
327
+
328
+ self.cross_attn = MultiKVCrossAttention(
329
+ hidden_dim, kv_dim_list, hidden_dim, num_heads
330
+ )
331
+ self.kv_size_list = kv_size_list
332
+ for i, kv_size in enumerate(kv_size_list):
333
+ if kv_size > 1:
334
+ setattr(
335
+ self,
336
+ "pos_embed_{}".format(i),
337
+ nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
338
+ )
339
+ # self.register_buffer("pos_embed_{}".format(i), torch.from_numpy(get_2d_sincos_pos_embed(hidden_dim, kv_size)).float(), persistent=False)
340
+
341
+ def forward(
342
+ self,
343
+ queries,
344
+ context_feature,
345
+ *vision_latents_attention_mask_list,
346
+ ) -> torch.FloatTensor:
347
+
348
+ residual = queries
349
+ # queries = self.proj_in(queries)
350
+ context_feature = self.proj_context(context_feature)
351
+ # queries = queries + context_feature
352
+ queries = torch.cat([queries, context_feature], -1)
353
+
354
+ # if self.num_of_kvs > 1:
355
+ # kv_weight = self.weight_mlp(queries) # B * 1 * num_tower
356
+ # kv_weight = kv_weight + self.tower_weight.view(1, 1, -1)
357
+ # kv_weight = kv_weight.softmax(-1)
358
+ # kv_number_list = [size**2 for size in self.kv_size_list]
359
+ # kv_weight = torch.repeat_interleave(kv_weight, torch.tensor(kv_number_list).to(kv_weight.device), dim=-1)
360
+ # else:
361
+ # kv_weight = None
362
+
363
+ queries = self.proj_in(queries)
364
+
365
+ vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
366
+ attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
367
+
368
+ attention_mask_list_reshaped = []
369
+ if attention_mask_list is not None:
370
+ for attention_mask in attention_mask_list:
371
+ attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
372
+ attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
373
+ attention_mask_list_reshaped.append(attention_mask)
374
+
375
+ vision_latents_pos_list = []
376
+ for i, vision_latents in enumerate(vision_latents_list):
377
+ if vision_latents.shape[1] > 1:
378
+ vision_latents_pos_list.append(
379
+ vision_latents
380
+ + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
381
+ vision_latents.dtype
382
+ )
383
+ )
384
+ else:
385
+ vision_latents_pos_list.append(vision_latents)
386
+
387
+ # Cross Attention
388
+ attention_output = self.cross_attn(
389
+ queries, *vision_latents_pos_list, *attention_mask_list_reshaped
390
+ )
391
+
392
+ # attention_output = (attention_output * combination_weight).sum(2)
393
+ queries = queries + attention_output
394
+
395
+ queries = self.norm(queries)
396
+
397
+ queries = self.proj_out(queries)
398
+
399
+ queries = queries + residual
400
+
401
+ return queries
402
+
403
+
404
+ class VisionAggregationLayer(nn.Module):
405
+ def __init__(
406
+ self,
407
+ q_dim,
408
+ context_dim,
409
+ kv_dim_list,
410
+ kv_size_list,
411
+ hidden_dim=1024,
412
+ layer_idx=0,
413
+ ):
414
+ super().__init__()
415
+ num_heads = 16
416
+ self.num_of_kvs = len(kv_dim_list)
417
+
418
+ self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
419
+ self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
420
+
421
+ self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)
422
+
423
+ self.norm = nn.LayerNorm(hidden_dim)
424
+
425
+ if self.num_of_kvs > 1:
426
+ self.weight_mlp = MLP(q_dim + hidden_dim, hidden_dim, self.num_of_kvs)
427
+
428
+ for i, kv_size in enumerate(kv_size_list):
429
+ if kv_size > 1:
430
+ setattr(
431
+ self,
432
+ "pos_embed_{}".format(i),
433
+ nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
434
+ )
435
+ setattr(
436
+ self,
437
+ "aggregate_{}".format(i),
438
+ AggregationBlock(
439
+ True, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
440
+ ),
441
+ )
442
+ else:
443
+ setattr(
444
+ self,
445
+ "aggregate_{}".format(i),
446
+ AggregationBlock(
447
+ False, hidden_dim, kv_dim_list[i], hidden_dim, num_heads
448
+ ),
449
+ )
450
+
451
+ def forward(
452
+ self,
453
+ queries,
454
+ context_feature,
455
+ *vision_latents_attention_mask_list,
456
+ ) -> torch.FloatTensor:
457
+
458
+ residual = queries
459
+ # queries = self.proj_in(queries)
460
+ context_feature = self.proj_context(context_feature)
461
+ # queries = queries + context_feature
462
+ queries = torch.cat([queries, context_feature], -1)
463
+
464
+ if self.num_of_kvs > 1:
465
+ combination_weight = self.weight_mlp(queries).softmax(
466
+ -1
467
+ ) # B * 1 * num_tower
468
+ combination_weight = combination_weight.unsqueeze(-1)
469
+ else:
470
+ combination_weight = 1
471
+
472
+ queries = self.proj_in(queries)
473
+
474
+ vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
475
+ attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]
476
+
477
+ attention_mask_list_reshaped = []
478
+ if attention_mask_list is not None:
479
+ for attention_mask in attention_mask_list:
480
+ attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
481
+ attention_mask = attention_mask.expand(-1, -1, queries.shape[1], -1)
482
+ attention_mask_list_reshaped.append(attention_mask)
483
+
484
+ vision_latents_pos_list = []
485
+ for i, vision_latents in enumerate(vision_latents_list):
486
+ if vision_latents.shape[1] > 1:
487
+ vision_latents_pos_list.append(
488
+ vision_latents
489
+ + getattr(self, "pos_embed_{}".format(i))[None, :, :].to(
490
+ vision_latents.dtype
491
+ )
492
+ )
493
+ else:
494
+ vision_latents_pos_list.append(vision_latents)
495
+
496
+ aggregated_vision_latents_list = []
497
+ for i, (vision_latents, attention_mask) in enumerate(
498
+ zip(vision_latents_pos_list, attention_mask_list_reshaped)
499
+ ):
500
+ aggregated_vision_latents_list.append(
501
+ getattr(self, "aggregate_{}".format(i))(
502
+ vision_latents, queries, attention_mask
503
+ )
504
+ )
505
+
506
+ aggregated_vision_latents = torch.stack(aggregated_vision_latents_list, 2)
507
+
508
+ queries = queries + (aggregated_vision_latents * combination_weight).sum(2)
509
+
510
+ queries = self.norm(queries)
511
+
512
+ queries = self.proj_out(queries)
513
+
514
+ queries = queries + residual
515
+
516
+ return queries
517
+
518
+
519
+ class VisionTokenSampler(nn.Module):
520
+ def __init__(
521
+ self,
522
+ q_dim,
523
+ context_dim,
524
+ kv_dim_list,
525
+ kv_size_list,
526
+ vision_hidden_size,
527
+ num_of_layers=1,
528
+ layer_type="joint",
529
+ ):
530
+ super().__init__()
531
+ assert layer_type in ["joint", "sep"]
532
+ if layer_type == "joint":
533
+ self.layers = nn.ModuleList(
534
+ [
535
+ VisionCrossAttentionLayer(
536
+ q_dim,
537
+ context_dim,
538
+ kv_dim_list,
539
+ kv_size_list,
540
+ vision_hidden_size,
541
+ idx,
542
+ )
543
+ for idx in range(num_of_layers)
544
+ ]
545
+ )
546
+ else:
547
+ self.layers = nn.ModuleList(
548
+ [
549
+ VisionAggregationLayer(
550
+ q_dim,
551
+ context_dim,
552
+ kv_dim_list,
553
+ kv_size_list,
554
+ vision_hidden_size,
555
+ idx,
556
+ )
557
+ for idx in range(num_of_layers)
558
+ ]
559
+ )
560
+
561
+ def forward(self, queries, context_feature, *vision_latents_attention_mask_list):
562
+ for layer in self.layers:
563
+ queries = layer(
564
+ queries, context_feature, *vision_latents_attention_mask_list
565
+ )
566
+ return queries
vocab.json ADDED
The diff for this file is too large to render. See raw diff