NaViT-style ratio-preserving image treatment

#2
Files changed (3)
  1. config.json +1 -1
  2. model.safetensors +2 -2
  3. modeling_siglip.py +76 -23
config.json CHANGED
@@ -20,7 +20,7 @@
   "transformers_version": "4.37.0.dev0",
   "vision_config": {
     "hidden_size": 1152,
-    "image_size": 384,
+    "image_size": 980,
     "intermediate_size": 4304,
     "model_type": "siglip_vision_model",
     "num_attention_heads": 16,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ea2abad2b7f8a9c1aa5e49a244d5d57ffa71c56f720c94bc5d240ef4d6e1d94a
-size 3511950624
+oid sha256:1ccea61f0d7617845a66fdf30bf2bcf0a090f7c74e8f7da2bf7b76e41ae4dfbc
+size 3531170592
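Note on the size bump: assuming the patch size stays at 14 (it is not shown in this diff) and fp32 weights, the growth comes entirely from the larger learned position embedding, 729 -> 4900 positions. A quick sanity check in Python:

hidden_size = 1152
old_positions = (384 // 14) ** 2   # 729
new_positions = (980 // 14) ** 2   # 4900
extra_bytes = (new_positions - old_positions) * hidden_size * 4  # fp32
print(extra_bytes)                  # 19219968
print(3531170592 - 3511950624)      # 19219968 -- matches the delta above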
modeling_siglip.py CHANGED
@@ -283,16 +283,44 @@ class SiglipVisionEmbeddings(nn.Module):
             padding="valid",
         )
 
-        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_patches_per_side = self.image_size // self.patch_size
+        self.num_patches = self.num_patches_per_side**2
         self.num_positions = self.num_patches
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
 
-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
+    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+        batch_size = pixel_values.size(0)
+
+        patch_embeds = self.patch_embedding(pixel_values)
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
 
-        embeddings = embeddings + self.position_embedding(self.position_ids)
+        max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+        position_ids = torch.full(
+            size=(
+                batch_size,
+                max_nb_patches_h * max_nb_patches_w,
+            ),
+            fill_value=0,
+        )
+
+        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+            nb_patches_h = p_attn_mask[:, 0].sum()
+            nb_patches_w = p_attn_mask[0].sum()
+
+            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
+
+            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+            pos_ids = (self.num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
+            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
+
+        position_ids = position_ids.to(self.position_embedding.weight.device)
+
+        embeddings = embeddings + self.position_embedding(position_ids)
         return embeddings
 
 
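The hunk above drops the fixed position_ids buffer in favour of a bucketized lookup, so an image kept at its native aspect ratio reuses a subset of the 70x70 learned positions instead of interpolating them. A standalone sketch of that lookup with made-up patch counts (a valid region of 3 x 9 patches); all numbers here are illustrative only:

import torch

# Illustrative values: a 70-patch-per-side full grid (980 / 14) and a 3 x 9 valid region.
num_patches_per_side = 70
nb_patches_h, nb_patches_w = 3, 9

boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

# Each fractional coordinate is snapped to a cell of the full grid.
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
print(bucket_coords_h)  # -> [0, 23, 46]
print(bucket_coords_w)  # -> [0, 7, 15, 23, 31, 38, 46, 54, 62]

# Same combination as in the hunk: one position id per valid patch.
pos_ids = (num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
print(pos_ids.shape)  # torch.Size([27])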
 
@@ -675,12 +703,12 @@ class SiglipPreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         """Initialize the weights"""
-
+
         if isinstance(module, SiglipVisionEmbeddings):
             width = (
                 self.config.vision_config.hidden_size
-                if isinstance(self.config, SiglipConfig)
-                else self.config.hidden_size
+                # if isinstance(self.config, SiglipConfig)
+                # else self.config.hidden_size
             )
             nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
         elif isinstance(module, nn.Embedding):
@@ -1055,6 +1083,7 @@ class SiglipVisionTransformer(nn.Module):
     def forward(
         self,
         pixel_values,
+        patch_attention_mask: Optional[torch.BoolTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
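For context on the new patch_attention_mask argument: a hypothetical caller-side sketch of batching two ratio-preserving resizes by padding pixel_values and marking the real patches. The sizes and the padding scheme are assumptions for illustration, not this repo's processing code:

import torch

patch_size = 14  # assumed from the so400m/14 family; not shown in this diff

# Two images resized with their aspect ratio preserved, dimensions kept as
# multiples of the patch size; both sizes are made up.
images = [torch.rand(3, 644, 980), torch.rand(3, 980, 490)]

max_h = max(im.shape[1] for im in images)
max_w = max(im.shape[2] for im in images)

pixel_values = torch.zeros(len(images), 3, max_h, max_w)
patch_attention_mask = torch.zeros(
    len(images), max_h // patch_size, max_w // patch_size, dtype=torch.bool
)
for i, im in enumerate(images):
    _, h, w = im.shape
    pixel_values[i, :, :h, :w] = im                                    # pad bottom/right
    patch_attention_mask[i, : h // patch_size, : w // patch_size] = True

print(pixel_values.shape)          # torch.Size([2, 3, 980, 980])
print(patch_attention_mask.shape)  # torch.Size([2, 70, 70])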
@@ -1069,10 +1098,29 @@
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        hidden_states = self.embeddings(pixel_values)
+        batch_size = pixel_values.size(0)
+        if patch_attention_mask is None:
+            patch_attention_mask = torch.ones(
+                size=(
+                    batch_size,
+                    pixel_values.size(2) // self.config.patch_size,
+                    pixel_values.size(3) // self.config.patch_size,
+                ),
+                dtype=torch.bool,
+                device=pixel_values.device,
+            )
+
+        hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
+
+        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
+            attention_mask=(
+                _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+                if not self.config._flash_attn_2_enabled
+                else patch_attention_mask
+            ),
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
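On the non-flash-attention branch above: _prepare_4d_attention_mask (from transformers' attention-mask utilities) expands the (batch, seq) boolean mask into an additive bias of shape (batch, 1, seq, seq). A rough hand-built equivalent with toy shapes, just to show what the encoder receives:

import torch

batch_size, seq_len = 2, 6
dtype = torch.float32

patch_attention_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
patch_attention_mask[0, :4] = True
patch_attention_mask[1, :] = True

# Broadcast the key mask over a query axis, then turn "masked" into a large negative bias.
expanded = patch_attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len)
additive_mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=dtype).masked_fill(
    ~expanded, torch.finfo(dtype).min
)
print(additive_mask.shape)  # torch.Size([2, 1, 6, 6])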
@@ -1081,7 +1129,10 @@
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.post_layernorm(last_hidden_state)
 
-        pooled_output = self.head(last_hidden_state)
+        pooled_output = self.head(
+            hidden_state=last_hidden_state,
+            attention_mask=patch_attention_mask,
+        )
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
@@ -1105,11 +1156,13 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(config)
 
-    def forward(self, hidden_state):
+    def forward(self, hidden_state, attention_mask):
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)
 
-        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+        hidden_state = self.attention(
+            query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask
+        )[0]
 
        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
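The pooling head now masks padded patches out of the probe attention. nn.MultiheadAttention's key_padding_mask treats True as "ignore this key", which is why the patch mask (True = real patch) is inverted with ~attention_mask. A toy-sized sketch of the same call:

import torch
from torch import nn

hidden_size, num_heads, batch_size, seq_len = 32, 4, 2, 10  # illustrative sizes only

attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
probe = torch.randn(1, 1, hidden_size)

hidden_state = torch.randn(batch_size, seq_len, hidden_size)
attention_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
attention_mask[0, :6] = True  # first sample: only 6 real patches
attention_mask[1, :] = True   # second sample: no padding

pooled = attention(
    query=probe.repeat(batch_size, 1, 1),
    key=hidden_state,
    value=hidden_state,
    key_padding_mask=~attention_mask,  # True = ignore this key
)[0]
print(pooled.shape)  # torch.Size([2, 1, 32])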
@@ -1185,17 +1238,17 @@ class SiglipModel(SiglipPreTrainedModel):
     def __init__(self, config: SiglipConfig):
         super().__init__(config)
 
-        if not isinstance(config.text_config, SiglipTextConfig):
-            raise ValueError(
-                "config.text_config is expected to be of type SiglipTextConfig but is of type"
-                f" {type(config.text_config)}."
-            )
-
-        if not isinstance(config.vision_config, SiglipVisionConfig):
-            raise ValueError(
-                "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
-                f" {type(config.vision_config)}."
-            )
+        # if not isinstance(config.text_config, SiglipTextConfig):
+        #     raise ValueError(
+        #         "config.text_config is expected to be of type SiglipTextConfig but is of type"
+        #         f" {type(config.text_config)}."
+        #     )
+
+        # if not isinstance(config.vision_config, SiglipVisionConfig):
+        #     raise ValueError(
+        #         "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
+        #         f" {type(config.vision_config)}."
+        #     )
 
         text_config = config.text_config
         vision_config = config.vision_config
 