jadechoghari
committed on
Commit
•
2b42a47
1
Parent(s):
52a563c
add initial files
Browse files- .DS_Store +0 -0
- .gitattributes +1 -0
- README.md +6 -0
- config.json +89 -0
- generation_config.json +14 -0
- merges.txt +0 -0
- modeling.py +471 -0
- multimodal_encoder_builder.py +368 -0
- multimodal_projector_builder.py +52 -0
- pytorch_model.bin +3 -0
- pytorch_model.bin.1 +3 -0
- special_tokens_map.json +20 -0
- tokenizer.json +0 -0
- tokenizer_config.json +53 -0
- vision_sampler.py +566 -0
- vocab.json +0 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
pytorch_model.bin.1 filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
datasets:
|
3 |
+
- shenxq/OneVision
|
4 |
+
base_model:
|
5 |
+
- Qwen/Qwen2-7B-Instruct
|
6 |
+
---
|
config.json
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "jadechoghari/LongVU_Qwen2_7B_img",
|
3 |
+
"architectures": [
|
4 |
+
"CambrianQwenForCausalLM"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "modeling.CambrianConfig",
|
8 |
+
"AutoModel": "modeling.CambrianQwenForCausalLM",
|
9 |
+
"AutoModelForCausalLM": "modeling.CambrianQwenForCausalLM"
|
10 |
+
},
|
11 |
+
"attention_bias": false,
|
12 |
+
"attention_dropout": 0.0,
|
13 |
+
"bos_token_id": 151643,
|
14 |
+
"connect_layer": 2,
|
15 |
+
"connector_depth": 3,
|
16 |
+
"connector_only": true,
|
17 |
+
"dino_threshold": 0.83,
|
18 |
+
"drop_threshold": 0.8,
|
19 |
+
"eos_token_id": 151645,
|
20 |
+
"frame_pos": false,
|
21 |
+
"freeze_mm_mlp_adapter": false,
|
22 |
+
"hidden_act": "silu",
|
23 |
+
"hidden_size": 3584,
|
24 |
+
"highres": false,
|
25 |
+
"highres_connect": false,
|
26 |
+
"image_aspect_ratio": "pad",
|
27 |
+
"image_position": 91,
|
28 |
+
"image_token_len": 576,
|
29 |
+
"initializer_range": 0.02,
|
30 |
+
"intermediate_size": 18944,
|
31 |
+
"is_image_newline": true,
|
32 |
+
"is_st_sampler": false,
|
33 |
+
"lowres_token": 8,
|
34 |
+
"max_position_embeddings": 32768,
|
35 |
+
"max_window_layers": 28,
|
36 |
+
"mm_patch_merge_type": "flat",
|
37 |
+
"mm_projector_lr": null,
|
38 |
+
"mm_projector_type": "sva",
|
39 |
+
"mm_use_im_patch_token": false,
|
40 |
+
"mm_use_im_start_end": false,
|
41 |
+
"mm_vision_sampler_lr": null,
|
42 |
+
"mm_vision_select_feature": "patch",
|
43 |
+
"mm_vision_select_layer": -2,
|
44 |
+
"mm_vision_tower_aux_list": [
|
45 |
+
"siglip/CLIP-ViT-SO400M-14-384",
|
46 |
+
"facebook/dinov2-giant-res378"
|
47 |
+
],
|
48 |
+
"mm_vision_tower_aux_token_len_list": [
|
49 |
+
576,
|
50 |
+
576
|
51 |
+
],
|
52 |
+
"mm_vision_tower_lr": null,
|
53 |
+
"model_type": "cambrian_qwen",
|
54 |
+
"num_attention_heads": 28,
|
55 |
+
"num_hidden_layers": 28,
|
56 |
+
"num_key_value_heads": 4,
|
57 |
+
"num_of_vision_sampler_layers": 10,
|
58 |
+
"num_query_group": 1,
|
59 |
+
"pretraining_tp": 1,
|
60 |
+
"query_num_list": [
|
61 |
+
576
|
62 |
+
],
|
63 |
+
"rms_norm_eps": 1e-06,
|
64 |
+
"rope_scaling": null,
|
65 |
+
"rope_theta": 1000000.0,
|
66 |
+
"sliding_window": null,
|
67 |
+
"spmd_debug": null,
|
68 |
+
"spmd_fsdp_sharding": null,
|
69 |
+
"spmd_mesh": null,
|
70 |
+
"start_of_vision_sampler_layers": 0,
|
71 |
+
"stride_of_vision_sampler_layers": 3,
|
72 |
+
"tie_word_embeddings": false,
|
73 |
+
"tokenizer_model_max_length": 8192,
|
74 |
+
"tokenizer_padding_side": "right",
|
75 |
+
"torch_dtype": "float32",
|
76 |
+
"transformers_version": "4.44.2",
|
77 |
+
"tune_mm_mlp_adapter": false,
|
78 |
+
"unfreeze_mm_vision_tower": false,
|
79 |
+
"use_cache": false,
|
80 |
+
"use_mm_proj": true,
|
81 |
+
"use_pos_skipping": false,
|
82 |
+
"use_sliding_window": false,
|
83 |
+
"vision_hidden_size": 1024,
|
84 |
+
"vision_tower_aux_token_len_list": [
|
85 |
+
576,
|
86 |
+
576
|
87 |
+
],
|
88 |
+
"vocab_size": 152064
|
89 |
+
}
|
generation_config.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token_id": 151643,
|
3 |
+
"do_sample": true,
|
4 |
+
"eos_token_id": [
|
5 |
+
151645,
|
6 |
+
151643
|
7 |
+
],
|
8 |
+
"pad_token_id": 151643,
|
9 |
+
"repetition_penalty": 1.05,
|
10 |
+
"temperature": 0.7,
|
11 |
+
"top_k": 20,
|
12 |
+
"top_p": 0.8,
|
13 |
+
"transformers_version": "4.40.0.dev0"
|
14 |
+
}
|
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
modeling.py
ADDED
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch.nn import CrossEntropyLoss
|
22 |
+
|
23 |
+
from transformers import AutoConfig, AutoModelForCausalLM
|
24 |
+
from transformers.cache_utils import Cache, DynamicCache
|
25 |
+
from transformers.generation.utils import GenerateOutput
|
26 |
+
|
27 |
+
from transformers.modeling_outputs import (
|
28 |
+
BaseModelOutputWithPast,
|
29 |
+
CausalLMOutputWithPast,
|
30 |
+
)
|
31 |
+
from transformers.utils import logging
|
32 |
+
|
33 |
+
from .cambrian_arch import CambrianMetaForCausalLM, CambrianMetaModel
|
34 |
+
|
35 |
+
IS_XLA_AVAILABLE = False
|
36 |
+
|
37 |
+
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2Model
|
38 |
+
|
39 |
+
logger = logging.get_logger(__name__)
|
40 |
+
|
41 |
+
|
42 |
+
class CambrianConfig(Qwen2Config):
    """Configuration class for the Cambrian/Qwen2 multimodal language model.

    Inherits every field from ``Qwen2Config`` and only overrides the
    ``model_type`` identifier, which is the key this module registers with
    ``AutoConfig``.
    """

    # Identifier stored in config.json and used for Auto* class lookup.
    model_type = "cambrian_qwen"

    # Class-level marker string; nothing in this module reads it.
    debug = "debug"
|
46 |
+
|
47 |
+
|
48 |
+
class CambrianQwenModel(CambrianMetaModel, Qwen2Model):
    """Qwen2 decoder stack combined with the Cambrian multimodal mixin.

    The forward pass below runs the token embedding, every decoder layer and
    the final norm. The vision-related keyword arguments are accepted so the
    causal-LM wrapper can pass them, but this method does not read them.
    """

    config_class = CambrianConfig

    def __init__(self, config: Qwen2Config):
        super().__init__(config)

    def forward(
        self,
        # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        vision_tower_aux_feature_list: Optional[List[torch.FloatTensor]] = None,
        vision_tower_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
        final_vision_feature_size: Optional[List[tuple]] = None,
        global_context_feature: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack over token ids or pre-computed embeddings.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        Flags left as ``None`` fall back to the corresponding config value.

        Returns:
            A ``BaseModelOutputWithPast`` when ``return_dict`` resolves to a
            truthy value, otherwise a tuple of the non-``None`` entries among
            (last hidden state, cache, all hidden states, all attentions).

        Raises:
            ValueError: if both or neither of ``input_ids`` and
                ``inputs_embeds`` are supplied.
        """
        # pyre-fixme[16]: `CambrianQwenModel` has no attribute `config`.
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = cfg.use_cache
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # Both missing or both present is an error.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        # pyre-fixme[16]: `CambrianQwenModel` has no attribute `gradient_checkpointing`.
        # pyre-fixme[16]: `CambrianQwenModel` has no attribute `training`.
        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # Wrap tuple-style (or absent) caches in a DynamicCache and remember
        # to convert back to the legacy format on the way out.
        came_as_legacy = False
        if use_cache and not isinstance(past_key_values, Cache):
            came_as_legacy = True
            # pyre-fixme[6]: For 1st argument expected
            #  `Optional[Tuple[Tuple[FloatTensor]]]` but got
            #  `Optional[List[FloatTensor]]`.
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
            )

        if inputs_embeds is None:
            # pyre-fixme[16]: `CambrianQwenModel` has no attribute `embed_tokens`.
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            if past_key_values is None:
                already_seen = 0
            else:
                # pyre-fixme[16]: Item `List` of `Union[List[torch._C.FloatTensor],
                #  DynamicCache]` has no attribute `get_seq_length`.
                already_seen = past_key_values.get_seq_length()
            cache_position = torch.arange(
                already_seen,
                already_seen + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # pyre-fixme[16]: `CambrianQwenModel` has no attribute `_update_causal_mask`.
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        hidden_states = inputs_embeds

        # Accumulators are tuples only when the matching output was requested.
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None
        # Position of the cache inside a layer's output tuple.
        cache_slot = 2 if output_attentions else 1

        # pyre-fixme[16]: `CambrianQwenModel` has no attribute `layers`.
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if checkpointing:
                # pyre-fixme[16]: `CambrianQwenModel` has no attribute
                #  `_gradient_checkpointing_func`.
                layer_out = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_out = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_out[0]

            if use_cache:
                next_decoder_cache = layer_out[cache_slot]

            if output_attentions:
                all_self_attns += (layer_out[1],)

        # pyre-fixme[16]: `CambrianQwenModel` has no attribute `norm`.
        hidden_states = self.norm(hidden_states)

        # The normalised output of the last layer closes the hidden-state list.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not use_cache:
            next_cache = None
        elif came_as_legacy:
            next_cache = next_decoder_cache.to_legacy_cache()
        else:
            next_cache = next_decoder_cache

        if return_dict:
            return BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=next_cache,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
            )
        candidates = (hidden_states, next_cache, all_hidden_states, all_self_attns)
        return tuple(item for item in candidates if item is not None)
|
213 |
+
|
214 |
+
|
215 |
+
class CambrianQwenForCausalLM(Qwen2ForCausalLM, CambrianMetaForCausalLM):
    """Causal-LM head on top of ``CambrianQwenModel`` with image inputs.

    ``forward`` and ``generate`` both route raw inputs through
    ``prepare_inputs_labels_for_multimodal`` (not defined in this class;
    presumably provided by ``CambrianMetaForCausalLM``) before the text
    decoder is run.
    """

    config_class = CambrianConfig

    def __init__(self, config):
        """Build the decoder and LM head from ``config``.

        Note that ``config`` is mutated: ``model_type`` is forced to
        ``"cambrian_qwen"`` and ``rope_scaling`` is reset to ``None``.
        """
        Qwen2ForCausalLM.__init__(self, config)
        config.model_type = "cambrian_qwen"
        config.rope_scaling = None

        # Replace the plain Qwen2 backbone created by the parent initialiser.
        self.model = CambrianQwenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the underlying ``CambrianQwenModel``."""
        return self.model

    def forward(
        self,
        # pyre-fixme[9]: input_ids has type `LongTensor`; used as `None`.
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_aux_attention_masks_list: Optional[List[torch.Tensor]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        modalities: Optional[List[str]] = None,
        dpo_forward: Optional[bool] = False,
        cache_position=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute logits (and optionally the LM loss) for a multimodal batch.

        Args:
            input_ids, attention_mask, position_ids, past_key_values, labels:
                Passed through ``prepare_inputs_labels_for_multimodal`` when
                ``inputs_embeds`` is ``None``; its outputs replace them.
            inputs_embeds: Pre-computed embeddings. When given, the
                multimodal preparation step is skipped.
            images, image_aux_attention_masks_list, image_sizes: Forwarded to
                the preparation step only.
            return_dict: Forwarded to the decoder. It is NOT resolved against
                the config here, so leaving it ``None`` yields the tuple
                return form of this method.
            modalities: Accepted for call-site compatibility; not read.
                (Previously defaulted to a shared mutable list.)
            dpo_forward: When truthy, return ``(logits, labels)`` right after
                the LM head, skipping the float cast and the loss.
            cache_position: Accepted for call-site compatibility; not read.

        Returns:
            ``(logits, labels)`` when ``dpo_forward`` is truthy; otherwise a
            tuple (``loss`` first when labels were given) if ``return_dict``
            is falsy, else a ``CausalLMOutputWithPast``.
        """
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                vision_tower_aux_feature_list,
                vision_tower_aux_attention_masks_list,
                final_vision_feature_size,
                global_context_feature,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_aux_attention_masks_list,
                image_sizes,
            )

        # Arguments shared by every decoder invocation below.
        decoder_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if dpo_forward:
            # pyre-fixme[29]: `CambrianQwenModel` is not a function.
            outputs = self.model(**decoder_kwargs)
            return self.lm_head(outputs[0]), labels

        # The vision attributes exist only after `generate` has stored them.
        if hasattr(self, "vision_tower_aux_feature_list"):
            # NOTE(review): `inputs_embeds` may have been reassigned by the
            # preparation step above, so the freshly prepared locals are used
            # only when that step returned no embeddings; in every other case
            # the values cached on `self` by `generate` are passed instead.
            if inputs_embeds is None:
                decoder_kwargs.update(
                    vision_tower_aux_feature_list=vision_tower_aux_feature_list,
                    vision_tower_aux_attention_masks_list=vision_tower_aux_attention_masks_list,
                    final_vision_feature_size=final_vision_feature_size,
                    global_context_feature=global_context_feature,
                )
            else:
                decoder_kwargs.update(
                    vision_tower_aux_feature_list=self.vision_tower_aux_feature_list,
                    vision_tower_aux_attention_masks_list=self.vision_tower_aux_attention_masks_list,
                    final_vision_feature_size=self.final_vision_feature_size,
                    global_context_feature=self.global_context_feature,
                )

        # pyre-fixme[29]: `CambrianQwenModel` is not a function.
        outputs = self.model(**decoder_kwargs)

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute `config`.
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from token ids, optionally conditioned on images.

        With ``images`` the inputs are first merged by
        ``prepare_inputs_labels_for_multimodal`` and the returned vision
        features are stored on ``self`` so later ``forward`` calls can read
        them. Without images, ``inputs`` is embedded directly. In both cases
        the parent ``generate`` is driven through ``inputs_embeds``.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is passed in ``kwargs``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _,
                vision_tower_aux_feature_list,
                vision_tower_aux_attention_masks_list,
                final_vision_feature_size,
                global_context_feature,
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes,
            )
            # Cache the vision features on the instance; `forward` checks for
            # these attributes with `hasattr`.
            # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
            #  `vision_tower_aux_feature_list`.
            self.vision_tower_aux_feature_list = vision_tower_aux_feature_list
            # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
            #  `vision_tower_aux_attention_masks_list`.
            self.vision_tower_aux_attention_masks_list = (
                vision_tower_aux_attention_masks_list
            )
            # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
            #  `final_vision_feature_size`.
            self.final_vision_feature_size = final_vision_feature_size
            # pyre-fixme[16]: `CambrianQwenForCausalLM` has no attribute
            #  `global_context_feature`.
            self.global_context_feature = global_context_feature
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        # pyre-fixme[16]: `Qwen2ForCausalLM` has no attribute `generate`.
        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
        """Extend the parent's per-step inputs with ``images``/``image_sizes``.

        Both keys are removed from ``kwargs`` before delegating to the parent
        and re-attached to the result only when they are not ``None``.
        """
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs
|
468 |
+
|
469 |
+
|
470 |
+
# Make the custom classes discoverable through the transformers Auto* API:
# the "cambrian_qwen" model type resolves to CambrianConfig, and a
# CambrianConfig resolves to CambrianQwenForCausalLM.
AutoConfig.register("cambrian_qwen", CambrianConfig)
AutoModelForCausalLM.register(CambrianConfig, CambrianQwenForCausalLM)
|
multimodal_encoder_builder.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pyre-unsafe
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from transformers import AutoImageProcessor, Dinov2Config, Dinov2Model, SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
|
10 |
+
class ProcessorWrapper:
    """Adapt a plain transform callable to an image-processor-like object.

    Exposes the three members this wrapper defines: ``crop_size``,
    ``image_mean`` and ``preprocess``.
    """

    # Default per-channel mean. Kept immutable here and copied per instance.
    # NOTE(review): values look like the usual CLIP normalisation mean — confirm.
    _DEFAULT_IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)

    def __init__(
        self,
        transform,
        height=378,
        width=378,
        image_mean=None,
    ):
        """Store the transform and the advertised crop size / mean.

        Args:
            transform: Callable applied to each image in ``preprocess``.
            height: Value reported under ``crop_size["height"]``.
            width: Value reported under ``crop_size["width"]``.
            image_mean: Per-channel mean to expose. ``None`` selects the
                default ``[0.48145466, 0.4578275, 0.40821073]``.
        """
        self._crop_size = {
            "height": height,
            "width": width,
        }
        self._transforms = transform
        # A fresh list per instance: with a list literal as the default, every
        # wrapper shared one object and an in-place edit leaked into all
        # instances created afterwards.
        if image_mean is None:
            image_mean = list(self._DEFAULT_IMAGE_MEAN)
        self.image_mean = image_mean

    @property
    def crop_size(self):
        """Dict with the ``height`` and ``width`` given at construction."""
        return self._crop_size

    def preprocess(self, image, return_tensors="pt"):
        """Apply the wrapped transform to a single image.

        Args:
            image: Passed to the transform unchanged; no type check or
                conversion is performed here.
            return_tensors: Accepted for interface compatibility; not read.

        Returns:
            ``{"pixel_values": [transformed_image]}``.
        """
        output = {}
        output["pixel_values"] = [self._transforms(image)]
        return output
|
35 |
+
|
36 |
+
|
37 |
+
class BaseVisionTower(nn.Module):
    """Common scaffolding for vision encoders used as multimodal towers.

    Subclasses provide ``load_model`` and ``_forward`` and are expected to set
    ``self.vision_tower`` (read by ``dtype``, ``device`` and ``config``).
    The geometry properties first try the loaded/stub config and fall back to
    private ``_hidden_size`` / ``_image_size`` / ``_patch_size`` /
    ``_num_patches_per_side`` / ``_num_patches`` attributes when that fails.
    """

    def __init__(self, vision_tower_name, args, delay_load=False):
        """Record the tower name and the selection options found on ``args``.

        Args:
            vision_tower_name: Identifier stored as ``vision_tower_name``.
            args: Object that must expose ``mm_vision_select_layer``;
                ``mm_vision_select_feature`` (default ``"patch"``) and
                ``unfreeze_mm_vision_tower`` (default ``False``) are optional.
            delay_load: Stored as-is; this base class does not act on it.
        """
        super().__init__()

        self.is_loaded = False
        self.args = args

        self.vision_tower_name = vision_tower_name
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
        self.unfreeze_mm_vision_tower = getattr(args, "unfreeze_mm_vision_tower", False)
        self.delay_load = delay_load

    @abstractmethod
    def load_model(self, device_map=None):
        """Load the underlying encoder; must be overridden."""
        raise NotImplementedError("Subclasses must implement load_model")

    @abstractmethod
    def _forward(self, images):
        """Encode one batch tensor of images; must be overridden."""
        raise NotImplementedError("Subclasses must implement _forward")

    def forward(self, images):
        """Encode ``images``.

        A list is encoded element by element, each item getting a new leading
        dimension via ``unsqueeze(0)``, and a list of results is returned.
        Anything else is handed to ``_forward`` unchanged.
        """
        if isinstance(images, list):
            return [self._forward(image.unsqueeze(0)) for image in images]
        return self._forward(images)

    @property
    def dummy_feature(self):
        """All-zero tensor of shape ``(1, hidden_size)`` on this tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the wrapped tower.

        Uses ``vision_tower.dtype`` when that attribute exists, otherwise the
        dtype of its first parameter, otherwise ``torch.float32``.
        """
        if hasattr(self.vision_tower, "dtype"):
            return self.vision_tower.dtype
        params = list(self.vision_tower.parameters())
        return params[0].dtype if len(params) > 0 else torch.float32

    @property
    def device(self):
        """Device of the wrapped tower.

        Uses ``vision_tower.device`` when that attribute exists, otherwise the
        device of its first parameter, otherwise CPU.
        """
        if hasattr(self.vision_tower, "device"):
            return self.vision_tower.device
        params = list(self.vision_tower.parameters())
        return params[0].device if len(params) > 0 else torch.device("cpu")

    @property
    def config(self):
        """Config of the loaded tower, or the ``cfg_only`` stub before loading."""
        if self.is_loaded:
            return self.vision_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        """``config.hidden_size``, falling back to ``_hidden_size``."""
        # `except Exception` (not a bare except) so KeyboardInterrupt and
        # SystemExit still propagate; same for the properties below.
        try:
            return self.config.hidden_size
        except Exception:
            return self._hidden_size

    @property
    def image_size(self):  # resolution
        """``config.image_size``, falling back to ``_image_size``."""
        try:
            return self.config.image_size
        except Exception:
            return self._image_size

    @property
    def patch_size(self):
        """``config.patch_size``, falling back to ``_patch_size``."""
        try:
            return self.config.patch_size
        except Exception:
            return self._patch_size

    @property
    def num_patches_per_side(self):
        """Patches along one image side.

        If ``_interp_size`` is set and not ``None`` it wins and the result is
        ``int(_interp_size ** 0.5)``. Otherwise ``image_size // patch_size``,
        falling back to ``_num_patches_per_side``.
        """
        # getattr: a subclass that never sets `_interp_size` is treated the
        # same as one that sets it to None instead of raising AttributeError.
        interp_size = getattr(self, "_interp_size", None)
        if interp_size is not None:
            return int(interp_size**0.5)
        try:
            return self.image_size // self.patch_size
        except Exception:
            return self._num_patches_per_side

    @property
    def num_patches(self):
        """Total number of patches.

        ``_interp_size`` when set and not ``None``; otherwise
        ``num_patches_per_side ** 2``, falling back to ``_num_patches``.
        """
        interp_size = getattr(self, "_interp_size", None)
        if interp_size is not None:
            return interp_size
        try:
            return self.num_patches_per_side**2
        except Exception:
            return self._num_patches
|
139 |
+
|
140 |
+
|
141 |
+
class DinoVisionTower(BaseVisionTower):
    """Vision tower backed by a Hugging Face ``Dinov2Model``.

    The checkpoint name is hard-coded to ``facebook/dinov2-giant`` with an
    input resolution of 378 and an interpolation target of 576 tokens. The
    ``vision_tower`` constructor argument is only recorded in
    ``_vision_tower_name``; it does not select the checkpoint.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super(DinoVisionTower, self).__init__(vision_tower, args, delay_load)

        model_path = "facebook/dinov2-giant"
        base_model_name, res, interp = model_path, 378, 576
        self._vision_tower_name = vision_tower
        self.vision_tower_name = base_model_name
        self._image_size = res
        self._interp_size = interp
        self._patch_size = 14  # default patch size

        # NOTE(review): ``delay_load`` is presumably set on ``self`` by the
        # base class constructor - confirm against BaseVisionTower.
        if not self.delay_load:
            self.load_model()
        else:
            self.cfg_only = Dinov2Config.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Load the DINOv2 weights and image processor.

        ``device_map`` is accepted but not used by this method.
        """
        self.vision_tower = Dinov2Model.from_pretrained(self.vision_tower_name)
        # Without this attribute, loading fails with:
        # "ValueError: Dinov2Model does not support `device_map='auto'`. To
        # implement support, the model class needs to implement the
        # `_no_split_modules` attribute."
        self.vision_tower._no_split_modules = ["Dinov2SwiGLUFFN"]

        if self._image_size is None:
            self._image_size = self.vision_tower.config.image_size

        # The shortest edge is resized to exactly the crop size
        # (an enlarged 8/7 ratio was previously considered and disabled).
        shortest_edge = self._image_size

        processor = AutoImageProcessor.from_pretrained(
            self.vision_tower_name,
            crop_size=dict(height=self._image_size, width=self._image_size),
            size=dict(shortest_edge=shortest_edge),
        )
        self.image_processor = processor

        # Assign the output channels of the projection convolution as the hidden size
        self._hidden_size = (
            self.vision_tower.embeddings.patch_embeddings.projection.out_channels
        )
        # Assign the first value of the stride of the projection convolution as the patch size
        self._patch_size = (
            self.vision_tower.embeddings.patch_embeddings.projection.stride[0]
        )

        self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
        self.is_loaded = True

    @property
    def image_size(self):
        """Return the resolution stored in ``self._image_size``."""
        return self._image_size

    def feature_select(self, outputs):
        """Pick tokens from ``outputs["last_hidden_state"]`` per ``self.select_feature``.

        ``"cls_patch"`` keeps every token, ``"patch"`` drops the first token and
        ``"cls"`` keeps only the first token.

        Raises:
            ValueError: if ``self.select_feature`` is none of the above.
        """
        sequence_output = outputs[
            "last_hidden_state"
        ]  # batch_size, sequence_length, hidden_size

        if self.select_feature == "cls_patch":
            image_features = sequence_output
        elif self.select_feature == "patch":
            image_features = sequence_output[:, 1:]
        elif self.select_feature == "cls":
            image_features = sequence_output[:, 0]
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def interpolate(self, image_features):
        """Bilinearly resize the token grid to ``self._interp_size`` tokens.

        Features are returned unchanged when ``self._interp_size`` is ``None``
        or when the token count already equals ``self.num_patches``.

        Raises:
            ValueError: if the token count is not a perfect square, so the
                tokens cannot be arranged as a square grid.
        """
        if self._interp_size is None:
            return image_features

        b, num_tokens, dim = image_features.shape

        if num_tokens != self.num_patches:
            target_h = target_w = int(self._interp_size**0.5)
            h = w = round(num_tokens**0.5)
            if h * w != num_tokens:
                raise ValueError(
                    f"Cannot interpolate {num_tokens} tokens: not a square grid"
                )

            # reshape (rather than view) also accepts non-contiguous slices
            image_features = image_features.reshape(b, h, w, dim)
            image_features = image_features.permute(0, 3, 1, 2).contiguous()

            # Interpolate in float32, then cast back to the incoming dtype.
            image_features = F.interpolate(
                image_features.to(torch.float32),
                size=(target_h, target_w),
                mode="bilinear",
                align_corners=False,
            ).to(image_features.dtype)

            # Permute the dimensions back to (b, target_h, target_w, dim)
            image_features = image_features.permute(0, 2, 3, 1).contiguous()

            # Flatten the spatial dimensions (target_h, target_w) into a single dimension
            image_features = image_features.flatten(1, 2)

        return image_features

    def _forward(self, images):
        """Run the tower on ``images`` and return the interpolated features.

        Gradients are enabled only when ``self.unfreeze_mm_vision_tower`` is
        truthy. The selected features are cast back to ``images.dtype``.
        """
        with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
            image_forward_outs = self.vision_tower.forward(
                images.to(device=self.device, dtype=self.dtype)
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)
            interp_features = self.interpolate(image_features)
            return interp_features

    @property
    def num_patches_per_side(self):
        """Return the integer part of the square root of ``num_patches``."""
        return int(self.num_patches**0.5)

    @property
    def num_patches(self):
        """Return ``_interp_size`` if set, else ``(_image_size // _patch_size) ** 2``."""
        if self._interp_size is None:
            return (self._image_size // self._patch_size) ** 2
        else:
            return self._interp_size
|
264 |
+
|
265 |
+
|
266 |
+
# from .siglip_encoder import SiglipVisionTower
|
267 |
+
class SiglipVisionTower(BaseVisionTower):
    """Vision tower backed by a Hugging Face ``SiglipVisionModel``.

    The checkpoint name is hard-coded to ``google/siglip-so400m-patch14-384``
    with an input resolution of 384 and an interpolation target of 576 tokens;
    the ``vision_tower_name`` constructor argument is only forwarded to the
    base class.
    """

    def __init__(self, vision_tower_name, args, delay_load=False):
        super(SiglipVisionTower, self).__init__(vision_tower_name, args, delay_load)

        model_path = "google/siglip-so400m-patch14-384"
        base_model_name, res, interp = model_path, 384, 576
        self.vision_tower_name = base_model_name
        self._image_size = res
        self._interp_size = interp
        # The model is loaded immediately unless loading is delayed AND the
        # tower is frozen; in that case only the hidden size is recorded.
        if not self.delay_load:
            self.load_model()
        elif self.unfreeze_mm_vision_tower:
            self.load_model()
        else:
            self._hidden_size = 1152

    def load_model(self, device_map=None):
        """Load the SigLIP weights and image processor.

        ``device_map`` is accepted but not used by this method.
        """
        self.vision_model = "siglip"
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)

        self.vision_tower.output_tokens = True

        # Replace the constructor defaults with the loaded model's config values.
        self._hidden_size = self.vision_tower.config.hidden_size
        self._image_size = self.vision_tower.config.image_size
        self._patch_size = self.vision_tower.config.patch_size
        self.image_processor = SiglipImageProcessor.from_pretrained(
            self.vision_tower_name
        )

        self.vision_tower.requires_grad_(self.unfreeze_mm_vision_tower)
        self.is_loaded = True

    def interpolate(self, image_features):
        """Bilinearly resize the token grid to ``self._interp_size`` tokens.

        Features are returned unchanged when ``self._interp_size`` is ``None``
        or when the token count already equals ``self.num_patches``.

        Raises:
            ValueError: if the token count is not a perfect square, so the
                tokens cannot be arranged as a square grid.
        """
        if self._interp_size is None:
            return image_features

        b, num_tokens, dim = image_features.shape

        if num_tokens != self.num_patches:
            target_h = target_w = int(self._interp_size**0.5)
            h = w = round(num_tokens**0.5)
            if h * w != num_tokens:
                raise ValueError(
                    f"Cannot interpolate {num_tokens} tokens: not a square grid"
                )

            # reshape (rather than view) also accepts non-contiguous inputs
            image_features = image_features.reshape(b, h, w, dim)
            image_features = image_features.permute(0, 3, 1, 2).contiguous()

            # Interpolate in float32, then cast back to the incoming dtype.
            image_features = F.interpolate(
                image_features.to(torch.float32),
                size=(target_h, target_w),
                mode="bilinear",
                align_corners=False,
            ).to(image_features.dtype)

            # Permute the dimensions back to (b, target_h, target_w, dim)
            image_features = image_features.permute(0, 2, 3, 1).contiguous()

            # Flatten the spatial dimensions (target_h, target_w) into a single dimension
            image_features = image_features.flatten(1, 2)

        return image_features

    def _forward(self, images, interpolate_token=576):
        """Run the tower on ``images`` and return the interpolated features.

        The last entry of the model's ``hidden_states`` is used as the feature
        map. ``interpolate_token`` is accepted but not read by this method; the
        target size comes from ``self._interp_size``. Gradients are enabled
        only when ``self.unfreeze_mm_vision_tower`` is truthy.
        """
        with torch.set_grad_enabled(self.unfreeze_mm_vision_tower):
            image_features = self.vision_tower.forward(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            ).hidden_states[-1]
            interp_features = self.interpolate(image_features)
            return interp_features
|
337 |
+
|
338 |
+
|
339 |
+
def build_vision_tower_aux_list(vision_tower_cfg, **kwargs):
    """Build the list of auxiliary vision towers described by ``vision_tower_cfg``.

    Tower names are read from ``mm_vision_tower_aux_list`` (falling back to
    ``vision_tower_aux_list``) and token lengths from
    ``mm_vision_tower_aux_token_len_list`` (falling back to
    ``vision_tower_aux_token_len_list``). Each name gets an
    ``-interp<token_len>`` suffix and is dispatched on a case-insensitive
    substring match: ``"siglip"`` -> ``SiglipVisionTower``, ``"dinov2"`` ->
    ``DinoVisionTower``. Every tower receives its own deep copy of the config.

    Args:
        vision_tower_cfg: config object carrying the attributes above.
        **kwargs: forwarded to each tower constructor.

    Returns:
        list of constructed vision towers, in config order.

    Raises:
        ValueError: if either list is missing from the config, or a tower
            name matches no known tower type.
    """
    vision_tower_aux_name_list = getattr(
        vision_tower_cfg,
        "mm_vision_tower_aux_list",
        getattr(vision_tower_cfg, "vision_tower_aux_list", None),
    )
    vision_tower_aux_token_len_list = getattr(
        vision_tower_cfg,
        "mm_vision_tower_aux_token_len_list",
        getattr(vision_tower_cfg, "vision_tower_aux_token_len_list", None),
    )
    # Fail with a readable message instead of zip()'s opaque
    # "argument must support iteration" TypeError.
    if vision_tower_aux_name_list is None or vision_tower_aux_token_len_list is None:
        raise ValueError(
            "vision tower config must define both the aux tower name list and "
            "the aux tower token length list"
        )

    vision_tower_aux_list = []
    for vision_tower_aux_name, vision_tower_aux_token_len in zip(
        vision_tower_aux_name_list, vision_tower_aux_token_len_list
    ):
        config = copy.deepcopy(vision_tower_cfg)
        vision_tower_aux_name += "-interp{}".format(vision_tower_aux_token_len)
        lowered_name = vision_tower_aux_name.lower()
        if "siglip" in lowered_name:
            vision_tower_aux_list.append(
                SiglipVisionTower(vision_tower_aux_name, args=config, **kwargs)
            )

        # SSL-based Vision Towers
        elif "dinov2" in lowered_name:
            vision_tower_aux_list.append(
                DinoVisionTower(vision_tower_aux_name, args=config, **kwargs)
            )
        else:
            raise ValueError(f"Unknown vision tower: {vision_tower_aux_name}")
    return vision_tower_aux_list
|
multimodal_projector_builder.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pyre-unsafe
|
2 |
+
import re
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class IdentityMap(nn.Module):
    """Projector that returns its first input unchanged."""

    @property
    def config(self):
        """Return the projector-type descriptor for this module."""
        return dict(mm_projector_type="identity")

    def forward(self, x, *args, **kwargs):
        """Return ``x`` as-is; any extra arguments are ignored."""
        return x
|
17 |
+
|
18 |
+
|
19 |
+
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual Linear-GELU-Linear block."""

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        layers = [
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``norm(x) + proj(norm(x))``.

        The skip connection carries the normalized input, not the raw ``x``.
        """
        normed = self.pre_norm(x)
        return normed + self.proj(normed)
|
31 |
+
|
32 |
+
|
33 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
    """Create the vision-to-language projector selected by ``config``.

    ``config.mm_projector_type`` (default ``"linear"``) picks the module:
    ``"linear"`` gives a single ``nn.Linear``, ``"mlp<N>x_gelu"`` gives N
    linear layers separated by GELU, and ``"identity"`` gives an
    ``IdentityMap``. ``delay_load`` and ``**kwargs`` are accepted but unused.

    Side effect: ``config.mm_hidden_size`` is overwritten with 256 before the
    projector is built.

    Raises:
        ValueError: if the projector type is not recognised.
    """
    projector_type = getattr(config, "mm_projector_type", "linear")
    config.mm_hidden_size = 256

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    depth_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
    if depth_match is not None:
        depth = int(depth_match.group(1))
        layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        # Every layer after the first is preceded by a GELU.
        layers.extend(
            layer
            for _ in range(depth - 1)
            for layer in (
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
        )
        return nn.Sequential(*layers)

    if projector_type != "identity":
        raise ValueError(f"Unknown projector type: {projector_type}")

    return IdentityMap()
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:02e0431bc1b9fdd5320ee41a5f24c194922a787282a2a6c39bd09e0d7c30f6a7
|
3 |
+
size 50329
|
pytorch_model.bin.1
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3584ed7ff1371bad4be307b8959d193ff3fa152164a9d47468e80245afa1c0f6
|
3 |
+
size 15343470478
|
special_tokens_map.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<|im_start|>",
|
4 |
+
"<|im_end|>"
|
5 |
+
],
|
6 |
+
"eos_token": {
|
7 |
+
"content": "<|im_end|>",
|
8 |
+
"lstrip": false,
|
9 |
+
"normalized": false,
|
10 |
+
"rstrip": false,
|
11 |
+
"single_word": false
|
12 |
+
},
|
13 |
+
"pad_token": {
|
14 |
+
"content": "<|endoftext|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false
|
19 |
+
}
|
20 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_prefix_space": false,
|
3 |
+
"added_tokens_decoder": {
|
4 |
+
"151643": {
|
5 |
+
"content": "<|endoftext|>",
|
6 |
+
"lstrip": false,
|
7 |
+
"normalized": false,
|
8 |
+
"rstrip": false,
|
9 |
+
"single_word": false,
|
10 |
+
"special": true
|
11 |
+
},
|
12 |
+
"151644": {
|
13 |
+
"content": "<|im_start|>",
|
14 |
+
"lstrip": false,
|
15 |
+
"normalized": false,
|
16 |
+
"rstrip": false,
|
17 |
+
"single_word": false,
|
18 |
+
"special": true
|
19 |
+
},
|
20 |
+
"151645": {
|
21 |
+
"content": "<|im_end|>",
|
22 |
+
"lstrip": false,
|
23 |
+
"normalized": false,
|
24 |
+
"rstrip": false,
|
25 |
+
"single_word": false,
|
26 |
+
"special": true
|
27 |
+
},
|
28 |
+
"151646": {
|
29 |
+
"content": "<image>",
|
30 |
+
"lstrip": false,
|
31 |
+
"normalized": false,
|
32 |
+
"rstrip": false,
|
33 |
+
"single_word": false,
|
34 |
+
"special": true
|
35 |
+
}
|
36 |
+
},
|
37 |
+
"additional_special_tokens": [
|
38 |
+
"<|im_start|>",
|
39 |
+
"<|im_end|>"
|
40 |
+
],
|
41 |
+
"bos_token": null,
|
42 |
+
"chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
43 |
+
"clean_up_tokenization_spaces": false,
|
44 |
+
"eos_token": "<|im_end|>",
|
45 |
+
"errors": "replace",
|
46 |
+
"model_max_length": 32768,
|
47 |
+
"pad_token": "<|endoftext|>",
|
48 |
+
"padding_side": "right",
|
49 |
+
"processor_class": "LlavaProcessor",
|
50 |
+
"split_special_tokens": false,
|
51 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
52 |
+
"unk_token": null
|
53 |
+
}
|
vision_sampler.py
ADDED
@@ -0,0 +1,566 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.utils.checkpoint
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
|
9 |
+
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
|
10 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a 2-D sine/cosine positional embedding for a square grid.

    grid_size: int of the grid height and width
    cls_token: when true, a row of zeros is prepended
    return:
    pos_embed: [grid_size*grid_size, embed_dim] without cls_token, or
    [1+grid_size*grid_size, embed_dim] with it
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # Both axes share the same coordinates; meshgrid puts w first.
    mesh = np.stack(np.meshgrid(axis, axis), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return pos_embed
    return np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
27 |
+
|
28 |
+
|
29 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-plane coordinate grid into sine/cosine embeddings.

    Half of ``embed_dim`` encodes ``grid[0]`` and the other half encodes
    ``grid[1]``; the two (H*W, D/2) halves are concatenated into (H*W, D).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    halves = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[plane])  # (H*W, D/2)
        for plane in (0, 1)
    ]
    return np.concatenate(halves, axis=1)  # (H*W, D)
|
38 |
+
|
39 |
+
|
40 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode positions as concatenated sine and cosine features.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; flattened to size (M,)
    out: (M, D), sines in the first D/2 columns and cosines in the rest
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=np.float32) / (embed_dim / 2.0)
    freqs = 1.0 / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.einsum("m,d->md", flat_pos, freqs)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
|
59 |
+
|
60 |
+
|
61 |
+
class CrossAttention(nn.Module):
    """Multi-head cross-attention from queries onto vision latents.

    Queries, keys and values each pass through a LayerNorm + Linear
    projection; the attention output is projected back to ``q_dim``.
    """

    def __init__(self, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if self.hidden_dim != self.head_dim * self.num_heads:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        inner_dim = self.num_heads * self.head_dim

        def normed_linear(in_dim):
            # LayerNorm on the raw input, then projection to the head space.
            return nn.Sequential(
                nn.LayerNorm(in_dim),
                nn.Linear(in_dim, inner_dim, bias=attention_bias),
            )

        self.q_proj = normed_linear(q_dim)
        self.k_proj = normed_linear(kv_dim)
        self.v_proj = normed_linear(kv_dim)
        self.o_proj = nn.Linear(inner_dim, q_dim, bias=attention_bias)

    def forward(self, vision_latents, queries, attention_mask):
        """Attend ``queries`` over ``vision_latents``.

        ``attention_mask`` may be ``None``; otherwise it must have size
        ``(bsz, 1, q_len, v_len)``, where ``bsz`` is taken from
        ``vision_latents``.

        Raises:
            ValueError: if the mask has a different size.
        """
        _, q_len, _ = queries.size()
        bsz, v_len, _ = vision_latents.size()

        projected_q = self.q_proj(queries)
        projected_k = self.k_proj(vision_latents)
        projected_v = self.v_proj(vision_latents)

        def split_heads(tensor, length):
            # (bsz, length, heads * head_dim) -> (bsz, heads, length, head_dim)
            return tensor.view(bsz, length, self.num_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(projected_q, q_len)
        key_states = split_heads(projected_k, v_len)
        value_states = split_heads(projected_v, v_len)

        has_mask = attention_mask is not None
        if has_mask and attention_mask.size() != (bsz, 1, q_len, v_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
            )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if has_mask and query_states.device.type == "cuda":
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attended = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_dim)
        merged = attended.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_dim)
        return self.o_proj(merged)
|
136 |
+
|
137 |
+
|
138 |
+
class AggregationBlock(nn.Module):
    """Aggregate vision latents with either cross-attention or an MLP.

    When ``attention`` is truthy the block wraps a ``CrossAttention`` layer;
    otherwise it wraps an ``MLP`` applied to the vision latents alone.
    """

    def __init__(
        self, attention, q_dim, kv_dim, hidden_dim, num_heads, attention_bias=False
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if self.hidden_dim != self.head_dim * self.num_heads:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.attention = attention
        self.attention_layer = (
            CrossAttention(q_dim, kv_dim, hidden_dim, num_heads, attention_bias)
            if attention
            else MLP(kv_dim, q_dim, q_dim)
        )

    def forward(self, vision_latents, queries, attention_mask):
        """Return the aggregated queries.

        In MLP mode ``queries`` and ``attention_mask`` are ignored.
        """
        if not self.attention:
            return self.attention_layer(vision_latents)
        return self.attention_layer(vision_latents, queries, attention_mask)
|
168 |
+
|
169 |
+
|
170 |
+
class MultiKVCrossAttention(nn.Module):
    """Cross-attention from one query stream onto several key/value sources.

    Each source ``i`` gets its own LayerNorm + Linear key and value
    projections, registered as ``k_proj_{i}`` and ``v_proj_{i}``. The
    projected keys and values are concatenated along the sequence dimension
    before attention.
    """

    def __init__(self, q_dim, kv_dim_list, hidden_dim, num_heads, attention_bias=False):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = self.hidden_dim // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_dim:
            raise ValueError(
                f"hidden_dim must be divisible by num_heads (got `hidden_dim`: {self.hidden_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Sequential(
            nn.LayerNorm(q_dim),
            nn.Linear(q_dim, self.num_heads * self.head_dim, bias=attention_bias),
        )
        self.num_of_kvs = len(kv_dim_list)
        # Attribute names (not a ModuleList) are kept so state-dict keys stay
        # ``k_proj_{i}`` / ``v_proj_{i}``.
        for i, kv_dim in enumerate(kv_dim_list):
            setattr(
                self,
                "k_proj_{}".format(i),
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
            setattr(
                self,
                "v_proj_{}".format(i),
                nn.Sequential(
                    nn.LayerNorm(kv_dim),
                    nn.Linear(
                        kv_dim, self.num_heads * self.head_dim, bias=attention_bias
                    ),
                ),
            )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, q_dim, bias=attention_bias
        )

    def forward(
        self,
        queries,
        *vision_latents_attention_mask_list,
    ):
        """Attend ``queries`` over all key/value sources.

        Args:
            queries: 3-D query tensor (unpacked as ``bsz, q_len, _``).
            *vision_latents_attention_mask_list: the first ``num_of_kvs``
                entries are the vision latents, one per source; any remaining
                entries are attention masks, concatenated along the last
                dimension. If no masks are passed, attention is unmasked.

        Returns:
            Attention output projected back to ``q_dim``.

        Raises:
            ValueError: if the concatenated mask is not of size
                ``(bsz, 1, q_len, v_len)``.
        """
        vision_latents_list = vision_latents_attention_mask_list[: self.num_of_kvs]
        attention_mask_list = vision_latents_attention_mask_list[self.num_of_kvs :]

        bsz, q_len, _ = queries.size()

        query_states = self.q_proj(queries)
        key_states = torch.cat(
            [
                getattr(self, "k_proj_{}".format(i))(vision_latents)
                for i, vision_latents in enumerate(vision_latents_list)
            ],
            dim=1,
        )
        value_states = torch.cat(
            [
                getattr(self, "v_proj_{}".format(i))(vision_latents)
                for i, vision_latents in enumerate(vision_latents_list)
            ],
            dim=1,
        )

        v_len = key_states.shape[1]

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, v_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # torch.cat rejects an empty sequence, so "no masks passed" has to be
        # mapped to None explicitly.
        attention_mask = (
            torch.cat(attention_mask_list, dim=-1) if attention_mask_list else None
        )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, v_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, v_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output
|
292 |
+
|
293 |
+
|
294 |
+
class MLP(nn.Module):
    """Two bias-free linear layers with a GELU in between."""

    def __init__(self, d_in, d_hidden, d_out):
        super().__init__()
        self.linear_1 = nn.Linear(d_in, d_hidden, bias=False)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(d_hidden, d_out, bias=False)

    def forward(self, x):
        """Return ``linear_2(gelu(linear_1(x)))``."""
        hidden = self.linear_1(x)
        activated = self.act(hidden)
        return self.linear_2(activated)
|
303 |
+
|
304 |
+
|
305 |
+
class VisionCrossAttentionLayer(nn.Module):
    """Residual layer that refines queries by cross-attending to vision latents.

    Queries are concatenated with a projected context feature, projected to
    ``hidden_dim``, attended over every key/value source through a
    ``MultiKVCrossAttention`` with 16 heads, normalized, projected back to
    ``q_dim`` and added to the original queries. ``layer_idx`` is accepted but
    not used.
    """

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        super().__init__()
        head_count = 16
        self.num_of_kvs = len(kv_dim_list)

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(q_dim + hidden_dim, hidden_dim, bias=False)
        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        self.cross_attn = MultiKVCrossAttention(
            hidden_dim, kv_dim_list, hidden_dim, head_count
        )
        self.kv_size_list = kv_size_list
        # Sources with a side length above 1 get a learned positional
        # embedding with ``side ** 2`` rows.
        for idx, side in enumerate(kv_size_list):
            if side > 1:
                self.register_parameter(
                    "pos_embed_{}".format(idx),
                    nn.Parameter(torch.randn(side**2, hidden_dim)),
                )

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:
        """Return the refined queries.

        The first ``num_of_kvs`` variadic arguments are the vision latents;
        the rest are attention masks. Each mask is viewed as
        ``(batch, 1, 1, -1)`` and expanded over the query length.
        """
        skip = queries
        projected_context = self.proj_context(context_feature)
        fused = self.proj_in(torch.cat([queries, projected_context], -1))

        split_at = self.num_of_kvs
        latents = vision_latents_attention_mask_list[:split_at]
        masks = vision_latents_attention_mask_list[split_at:]

        q_len = fused.shape[1]
        expanded_masks = [
            mask.view(mask.shape[0], 1, 1, -1).expand(-1, -1, q_len, -1)
            for mask in masks
        ]

        latents_with_pos = []
        for idx, latent in enumerate(latents):
            # Only multi-token sources receive a positional embedding.
            if latent.shape[1] > 1:
                pos_embed = getattr(self, "pos_embed_{}".format(idx))[None, :, :]
                latent = latent + pos_embed.to(latent.dtype)
            latents_with_pos.append(latent)

        # Cross Attention
        attended = fused + self.cross_attn(fused, *latents_with_pos, *expanded_masks)

        return self.proj_out(self.norm(attended)) + skip
|
402 |
+
|
403 |
+
|
404 |
+
class VisionAggregationLayer(nn.Module):
    """Update queries by aggregating each vision source separately.

    Each key/value source ``i`` gets its own ``aggregate_{i}``
    ``AggregationBlock``. Sources whose ``kv_size`` is greater than 1 also
    get a learnable ``pos_embed_{i}`` of shape ``(kv_size**2, hidden_dim)``.
    When there is more than one source, ``weight_mlp`` predicts a softmax
    weight per source that is used to mix the per-source results.

    Args:
        q_dim: Feature size of the incoming (and returned) queries.
        context_dim: Feature size of ``context_feature``.
        kv_dim_list: One entry per vision source; its length fixes the
            number of sources, and entry ``i`` is forwarded to the
            ``AggregationBlock`` of source ``i``.
        kv_size_list: One entry per vision source; ``kv_size**2`` is the
            number of rows of that source's positional embedding.
        hidden_dim: Internal working width.
        layer_idx: Accepted for signature compatibility; not used here.
    """

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        hidden_dim=1024,
        layer_idx=0,
    ):
        super().__init__()
        num_heads = 16
        self.num_of_kvs = len(kv_dim_list)

        fused_dim = q_dim + hidden_dim  # width of cat([queries, projected context])

        self.proj_context = nn.Linear(context_dim, hidden_dim, bias=False)
        self.proj_in = nn.Linear(fused_dim, hidden_dim, bias=False)

        self.proj_out = MLP(hidden_dim, hidden_dim, q_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        # Mixing weights are only needed when several sources are combined.
        if self.num_of_kvs > 1:
            self.weight_mlp = MLP(fused_dim, hidden_dim, self.num_of_kvs)

        for idx, kv_size in enumerate(kv_size_list):
            has_grid = bool(kv_size > 1)
            if has_grid:
                # NOTE(review): the embedding width is hidden_dim and it is
                # added directly to the incoming latents in forward() --
                # confirm the latents' last dim is compatible.
                setattr(
                    self,
                    f"pos_embed_{idx}",
                    nn.Parameter(torch.randn(kv_size**2, hidden_dim)),
                )
            setattr(
                self,
                f"aggregate_{idx}",
                AggregationBlock(
                    has_grid, hidden_dim, kv_dim_list[idx], hidden_dim, num_heads
                ),
            )

    def forward(
        self,
        queries,
        context_feature,
        *vision_latents_attention_mask_list,
    ) -> torch.FloatTensor:
        """Fuse context into the queries and add the mixed vision aggregates.

        Args:
            queries: Query tensor; ``queries.shape[1]`` is used as the query
                length when expanding the attention masks.
            context_feature: Tensor projected by ``proj_context`` and
                concatenated to ``queries`` along the last dimension.
            *vision_latents_attention_mask_list: The first ``num_of_kvs``
                entries are the vision latents, the remaining entries are
                their attention masks. Latents and masks are paired with
                ``zip``, so only sources that have both are aggregated.

        Returns:
            ``proj_out(norm(hidden + mixed_aggregates)) + queries`` where
            ``queries`` is the original input.
        """
        skip = queries

        projected_context = self.proj_context(context_feature)
        fused = torch.cat([queries, projected_context], dim=-1)

        if self.num_of_kvs > 1:
            # Softmax over the per-source logits, then a trailing axis so the
            # weights broadcast against the stacked aggregates below.
            # presumably B * 1 * num_tower before the unsqueeze -- TODO confirm
            combination_weight = self.weight_mlp(fused).softmax(-1).unsqueeze(-1)
        else:
            combination_weight = 1

        hidden = self.proj_in(fused)
        num_queries = hidden.shape[1]

        latents_per_source = vision_latents_attention_mask_list[: self.num_of_kvs]
        masks_per_source = vision_latents_attention_mask_list[self.num_of_kvs :]

        # (B, ...) -> (B, 1, 1, L) -> (B, 1, num_queries, L)
        expanded_masks = [
            mask.view(mask.shape[0], 1, 1, -1).expand(-1, -1, num_queries, -1)
            for mask in masks_per_source
        ]

        # Positional embeddings are only applied to multi-token sources.
        positioned_latents = [
            latents
            + getattr(self, f"pos_embed_{idx}").unsqueeze(0).to(latents.dtype)
            if latents.shape[1] > 1
            else latents
            for idx, latents in enumerate(latents_per_source)
        ]

        per_source = [
            getattr(self, f"aggregate_{idx}")(latents, hidden, mask)
            for idx, (latents, mask) in enumerate(
                zip(positioned_latents, expanded_masks)
            )
        ]

        stacked = torch.stack(per_source, dim=2)
        hidden = hidden + (stacked * combination_weight).sum(dim=2)

        hidden = self.norm(hidden)
        hidden = self.proj_out(hidden)

        return hidden + skip
|
517 |
+
|
518 |
+
|
519 |
+
class VisionTokenSampler(nn.Module):
    """Sequential stack of vision sampling layers.

    ``layer_type`` picks the layer class used for every entry of
    ``self.layers``: ``"joint"`` builds ``VisionCrossAttentionLayer`` and
    ``"sep"`` builds ``VisionAggregationLayer``. Each layer is constructed
    with the same positional arguments followed by its index in the stack.

    Args:
        q_dim: Forwarded to every layer as the first argument.
        context_dim: Forwarded to every layer as the second argument.
        kv_dim_list: Forwarded to every layer as the third argument.
        kv_size_list: Forwarded to every layer as the fourth argument.
        vision_hidden_size: Forwarded to every layer as the fifth argument
            (``hidden_dim`` in ``VisionAggregationLayer``).
        num_of_layers: Number of layers to build.
        layer_type: Either ``"joint"`` or ``"sep"``.

    Raises:
        AssertionError: If ``layer_type`` is not one of the two known values.
    """

    def __init__(
        self,
        q_dim,
        context_dim,
        kv_dim_list,
        kv_size_list,
        vision_hidden_size,
        num_of_layers=1,
        layer_type="joint",
    ):
        super().__init__()
        assert layer_type in ["joint", "sep"]

        shared_args = (
            q_dim,
            context_dim,
            kv_dim_list,
            kv_size_list,
            vision_hidden_size,
        )

        def build_layer(depth):
            # The layer class is looked up only when a layer is actually built.
            if layer_type == "joint":
                return VisionCrossAttentionLayer(*shared_args, depth)
            return VisionAggregationLayer(*shared_args, depth)

        self.layers = nn.ModuleList(
            [build_layer(depth) for depth in range(num_of_layers)]
        )

    def forward(self, queries, context_feature, *vision_latents_attention_mask_list):
        """Run the queries through every layer in order.

        Each layer receives the output of the previous one together with the
        unchanged ``context_feature`` and the unchanged trailing arguments.
        With an empty stack the input ``queries`` is returned as is.
        """
        sampled = queries
        for block in self.layers:
            sampled = block(
                sampled, context_feature, *vision_latents_attention_mask_list
            )
        return sampled
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|