navit style ratio preserving image treatment #2
by VictorSanh - opened
- config.json +1 -1
- model.safetensors +2 -2
- modeling_siglip.py +76 -23
config.json
CHANGED
@@ -20,7 +20,7 @@
   "transformers_version": "4.37.0.dev0",
   "vision_config": {
     "hidden_size": 1152,
-    "image_size":
+    "image_size": 980,
     "intermediate_size": 4304,
     "model_type": "siglip_vision_model",
     "num_attention_heads": 16,
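The only configuration change is the vision tower's `image_size`, which becomes 980. Assuming the so400m family's `patch_size` of 14 (the patch size is not shown in this hunk), that gives a 70×70 patch grid and 4900 learned position embeddings, which is also why `model.safetensors` changes below. A minimal sketch of the arithmetic:

```python
# Sketch of how image_size drives the size of the position-embedding table.
# patch_size=14 is an assumption (so400m-patch14 family); it is not part of this diff.
image_size = 980
patch_size = 14

num_patches_per_side = image_size // patch_size  # 70
num_positions = num_patches_per_side ** 2        # 4900 learnable position embeddings
print(num_patches_per_side, num_positions)       # 70 4900
```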
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:1ccea61f0d7617845a66fdf30bf2bcf0a090f7c74e8f7da2bf7b76e41ae4dfbc
+size 3531170592
modeling_siglip.py
CHANGED
@@ -283,16 +283,44 @@ class SiglipVisionEmbeddings(nn.Module):
             padding="valid",
         )
 
-        self.
+        self.num_patches_per_side = self.image_size // self.patch_size
+        self.num_patches = self.num_patches_per_side**2
         self.num_positions = self.num_patches
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
 
-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-
+    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+        batch_size = pixel_values.size(0)
+
+        patch_embeds = self.patch_embedding(pixel_values)
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
 
-
+        max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+        position_ids = torch.full(
+            size=(
+                batch_size,
+                max_nb_patches_h * max_nb_patches_w,
+            ),
+            fill_value=0,
+        )
+
+        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+            nb_patches_h = p_attn_mask[:, 0].sum()
+            nb_patches_w = p_attn_mask[0].sum()
+
+            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+            pos_ids = (self.num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
+            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+        position_ids = position_ids.to(self.position_embedding.weight.device)
+
+        embeddings = embeddings + self.position_embedding(position_ids)
         return embeddings
 
 
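The rewritten `SiglipVisionEmbeddings.forward` is the core of the PR's NaViT-style, ratio-preserving treatment: an image can occupy only part of the padded canvas, and each real patch is assigned a position id on the fixed square grid by bucketizing its fractional row/column coordinate, so the existing square position-embedding table is reused for arbitrary aspect ratios. A standalone sketch of that bucketing step, with hypothetical sizes that are not taken from the diff:

```python
import torch

# Hypothetical sizes for illustration only.
num_patches_per_side = 70            # full square grid, e.g. 980 // 14
nb_patches_h, nb_patches_w = 35, 70  # a wide image that fills only the top half of the canvas

# Same bucketing as in the forward above: map each real patch row/column
# onto the fixed num_patches_per_side grid via its fractional position.
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

# One position id per real patch; in the model, padded patches keep the fill value 0.
pos_ids = (num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
print(bucket_coords_h.shape, bucket_coords_w.shape, pos_ids.shape)
# torch.Size([35]) torch.Size([70]) torch.Size([2450])
```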
@@ -675,12 +703,12 @@ class SiglipPreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         """Initialize the weights"""
-
+
         if isinstance(module, SiglipVisionEmbeddings):
             width = (
                 self.config.vision_config.hidden_size
-                if isinstance(self.config, SiglipConfig)
-                else self.config.hidden_size
+                # if isinstance(self.config, SiglipConfig)
+                # else self.config.hidden_size
             )
             nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
         elif isinstance(module, nn.Embedding):
@@ -1055,6 +1083,7 @@ class SiglipVisionTransformer(nn.Module):
     def forward(
         self,
         pixel_values,
+        patch_attention_mask: Optional[torch.BoolTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1069,10 +1098,29 @@
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-
+        batch_size = pixel_values.size(0)
+        if patch_attention_mask is None:
+            patch_attention_mask = torch.ones(
+                size=(
+                    batch_size,
+                    pixel_values.size(2) // self.config.patch_size,
+                    pixel_values.size(3) // self.config.patch_size,
+                ),
+                dtype=torch.bool,
+                device=pixel_values.device,
+            )
+
+        hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+
+        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
+            attention_mask=(
+                _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+                if not self.config._flash_attn_2_enabled
+                else patch_attention_mask
+            ),
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
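`SiglipVisionTransformer.forward` now takes an optional `patch_attention_mask`; when it is omitted, every patch is treated as valid, which recovers the original all-patches behaviour. A caller that resizes an image while preserving its aspect ratio and pads it onto the full canvas would build the inputs roughly as follows (an illustrative sketch, not the repository's processing code; the sizes and the `model.vision_model` call are assumptions):

```python
import torch

# Illustrative only: place a ratio-preserved image on a 980x980 canvas and
# mark which 14x14 patches contain real pixels.
canvas, patch = 980, 14
img = torch.rand(3, 490, 980)  # already resized with its aspect ratio preserved

pixel_values = torch.zeros(1, 3, canvas, canvas)
pixel_values[0, :, : img.shape[1], : img.shape[2]] = img

patch_attention_mask = torch.zeros(1, canvas // patch, canvas // patch, dtype=torch.bool)
patch_attention_mask[0, : img.shape[1] // patch, : img.shape[2] // patch] = True

# outputs = model.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
```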
@@ -1081,7 +1129,10 @@
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.post_layernorm(last_hidden_state)
 
-        pooled_output = self.head(
+        pooled_output = self.head(
+            hidden_state=last_hidden_state,
+            attention_mask=patch_attention_mask,
+        )
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@@ -1105,11 +1156,13 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(config)
 
-    def forward(self, hidden_state):
+    def forward(self, hidden_state, attention_mask):
         batch_size = hidden_state.shape[0]
         probe = self.probe.repeat(batch_size, 1, 1)
 
-        hidden_state = self.attention(
+        hidden_state = self.attention(
+            query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask
+        )[0]
 
         residual = hidden_state
         hidden_state = self.layernorm(hidden_state)
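The pooling head excludes padded patches through `key_padding_mask`. `torch.nn.MultiheadAttention` interprets `True` entries of `key_padding_mask` as positions to ignore, which is why the mask is inverted with `~attention_mask` (True meaning "real patch" becomes True meaning "skip"). A minimal standalone sketch of that convention, with made-up sizes:

```python
import torch
from torch import nn

# Made-up sizes; only the key_padding_mask convention matters here.
embed_dim, num_heads, seq_len = 8, 2, 5
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

probe = torch.rand(1, 1, embed_dim)          # a single pooling query, as in the head
patches = torch.rand(1, seq_len, embed_dim)  # patch features

attention_mask = torch.tensor([[True, True, True, False, False]])  # True = real patch
out, _ = attn(probe, patches, patches, key_padding_mask=~attention_mask)  # True = ignore
print(out.shape)  # torch.Size([1, 1, 8])
```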
@@ -1185,17 +1238,17 @@ class SiglipModel(SiglipPreTrainedModel):
     def __init__(self, config: SiglipConfig):
         super().__init__(config)
 
-        if not isinstance(config.text_config, SiglipTextConfig):
-            raise ValueError(
-                "config.text_config is expected to be of type SiglipTextConfig but is of type"
-                f" {type(config.text_config)}."
-            )
-
-        if not isinstance(config.vision_config, SiglipVisionConfig):
-            raise ValueError(
-                "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
-                f" {type(config.vision_config)}."
-            )
+        # if not isinstance(config.text_config, SiglipTextConfig):
+        #     raise ValueError(
+        #         "config.text_config is expected to be of type SiglipTextConfig but is of type"
+        #         f" {type(config.text_config)}."
+        #     )
+
+        # if not isinstance(config.vision_config, SiglipVisionConfig):
+        #     raise ValueError(
+        #         "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
+        #         f" {type(config.vision_config)}."
+        #     )
 
         text_config = config.text_config
         vision_config = config.vision_config