VictorSanh committed

Commit • 8084b2d
1 Parent(s): a9d91fb

formatting
modeling_siglip.py +24 -22
modeling_siglip.py
CHANGED
@@ -284,7 +284,7 @@ class SiglipVisionEmbeddings(nn.Module):
         )
 
         self.num_patches_per_side = self.image_size // self.patch_size
-        self.num_patches = self.num_patches_per_side
+        self.num_patches = self.num_patches_per_side**2
         self.num_positions = self.num_patches
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
 
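Note (not part of the commit): the new line sizes num_patches as the full patch grid, i.e. the square of the per-side count, which is what the position-embedding table is built from. A minimal sketch of the arithmetic, assuming for illustration image_size=384 and patch_size=14 (the so400m/14 384px configuration); these values are assumptions, not taken from this commit:

# Illustrative sketch only; config values are assumed, not read from the repo.
image_size, patch_size = 384, 14
num_patches_per_side = image_size // patch_size  # 27 patches along each side
num_patches = num_patches_per_side**2            # 27 * 27 = 729 patches in total
num_positions = num_patches                      # one learned position embedding per patch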
@@ -295,16 +295,22 @@ class SiglipVisionEmbeddings(nn.Module):
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
 
         max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
-        max_nb_patches_h, max_nb_patches_w = max_im_h//self.patch_size, max_im_w//self.patch_size
-        boundaries = torch.arange(1/self.num_patches_per_side, 1.0, 1/self.num_patches_per_side)
-        position_ids = torch.full(size=(batch_size, max_nb_patches_h*max_nb_patches_w), fill_value=0)
+        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
+        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
+        position_ids = torch.full(
+            size=(
+                batch_size,
+                max_nb_patches_h * max_nb_patches_w,
+            ),
+            fill_value=0,
+        )
 
         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
             nb_patches_h = p_attn_mask[:, 0].sum()
             nb_patches_w = p_attn_mask[0].sum()
 
-            fractional_coords_h = torch.arange(0, 1-1e-6, 1/nb_patches_h)
-            fractional_coords_w = torch.arange(0, 1-1e-6, 1/nb_patches_w)
+            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
 
             bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
             bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
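Note (not part of the commit): the logic in this hunk assigns position ids to a variable-resolution patch grid by bucketing each patch's fractional coordinate onto the full num_patches_per_side grid, so one learned position table serves images of different sizes. A standalone sketch of the idea, with made-up grid sizes and one plausible way to combine the bucketed rows and columns into flat ids:

# Standalone, illustrative sketch; grid sizes and the final flatten step are assumptions.
import torch

num_patches_per_side = 27              # full grid the position table covers
nb_patches_h, nb_patches_w = 16, 12    # valid (unpadded) patch grid of one image

boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

# Map each fractional coordinate to the nearest cell of the full grid.
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

# Combine (row, col) buckets into a single position id per patch, row-major.
pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
print(pos_ids.shape)  # torch.Size([192]) == nb_patches_h * nb_patches_w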
@@ -1095,27 +1101,26 @@ class SiglipVisionTransformer(nn.Module):
         batch_size = pixel_values.size(0)
         if patch_attention_mask is None:
             patch_attention_mask = torch.ones(
-                size=(batch_size, pixel_values.size(2)//self.config.patch_size, pixel_values.size(3)//self.config.patch_size),
+                size=(
+                    batch_size,
+                    pixel_values.size(2) // self.config.patch_size,
+                    pixel_values.size(3) // self.config.patch_size,
+                ),
                 dtype=torch.bool,
                 device=pixel_values.device,
             )
-        # if pixel_attention_mask is None:
-        #     # assuming `pixel_attention_mask` is of size bs x h x w
-        #     pixel_attention_mask = torch.ones(size=(batch_size, pixel_values.size(2), pixel_values.size(3)), dtype=torch.bool, device=pixel_values.device)
-
-        #     subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
-        #     patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
 
-        hidden_states = self.embeddings(
-            pixel_values=pixel_values,
-            patch_attention_mask=patch_attention_mask
-        )
+        hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
 
         patch_attention_mask = patch_attention_mask.view(batch_size, -1)
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
-            attention_mask=_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) if not self.config._flash_attn_2_enabled else patch_attention_mask,
+            attention_mask=(
+                _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+                if not self.config._flash_attn_2_enabled
+                else patch_attention_mask
+            ),
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
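Note (not part of the commit): _prepare_4d_attention_mask is imported from transformers' attention-mask utilities; roughly, it expands a [batch, seq_len] boolean padding mask into the additive [batch, 1, tgt_len, src_len] float mask that the eager/SDPA attention path expects, while the flash-attention path consumes the 2D mask directly. A rough standalone approximation of that expansion, not the library implementation:

# Rough approximation of building a 4D additive mask from a 2D padding mask;
# illustrative only, NOT the transformers implementation.
import torch

def expand_padding_mask(mask_2d: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # mask_2d: [batch, src_len], True for real patches, False for padding.
    bsz, src_len = mask_2d.shape
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, src_len, src_len).to(dtype)
    # Valid positions contribute 0; padded positions get a large negative bias,
    # so softmax assigns them (near-)zero attention weight.
    return (1.0 - expanded) * torch.finfo(dtype).min

patch_attention_mask = torch.tensor([[True, True, False, False]])
print(expand_padding_mask(patch_attention_mask, torch.float32).shape)  # torch.Size([1, 1, 4, 4])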
@@ -1156,10 +1161,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         probe = self.probe.repeat(batch_size, 1, 1)
 
         hidden_state = self.attention(
-            query=probe,
-            key=hidden_state,
-            value=hidden_state,
-            key_padding_mask=~attention_mask
+            query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask
         )[0]
 
         residual = hidden_state
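Note (not part of the commit): collapsing the call onto one line does not change its semantics. torch.nn.MultiheadAttention's key_padding_mask marks key positions to be ignored (True = ignore), so the boolean patch mask (True = real patch) is inverted with ~. A minimal sketch of that convention, with made-up sizes:

# Minimal sketch of attention pooling with a key_padding_mask; sizes are illustrative.
import torch
import torch.nn as nn

embed_dim, num_heads, batch, seq = 64, 4, 2, 5
attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

probe = torch.zeros(batch, 1, embed_dim)            # one learned query per image
hidden_state = torch.randn(batch, seq, embed_dim)   # patch embeddings
attention_mask = torch.tensor([[True] * 5, [True, True, True, False, False]])

# key_padding_mask expects True where keys should be IGNORED, hence the inversion.
pooled = attention(query=probe, key=hidden_state, value=hidden_state,
                   key_padding_mask=~attention_mask)[0]
print(pooled.shape)  # torch.Size([2, 1, 64])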