VictorSanh commited on
Commit
ce019df
1 Parent(s): a8b0561

fix pos_ids

Browse files
Files changed (1) hide show
  1. modeling_siglip.py +1 -1
modeling_siglip.py CHANGED
@@ -323,7 +323,7 @@ class SiglipVisionEmbeddings(nn.Module):
323
  bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
324
  bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
325
 
326
- pos_ids = (self.num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
327
  position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
328
 
329
  position_ids = position_ids.to(self.position_embedding.weight.device)
 
323
  bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
324
  bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
325
 
326
+ pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
327
  position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
328
 
329
  position_ids = position_ids.to(self.position_embedding.weight.device)