howard-hou committed on
Commit
7bcd65d
1 Parent(s): 26f043f

Update modeling_vision.py

Browse files
Files changed (1) hide show
  1. modeling_vision.py +6 -6
modeling_vision.py CHANGED
@@ -28,14 +28,14 @@ class VisionEncoder(nn.Module):
28
  return self.proj(image_features)
29
 
30
  def grid_pooling(self, image_features):
 
 
31
  if self.args.grid_size == -1: # no grid pooling
32
- return image_features
33
  if self.args.grid_size == 0: # take cls token
34
- return image_features[:, 0:1, :]
35
  if self.args.grid_size == 1: # global avg pooling
36
- return image_features.mean(dim=1, keepdim=True)
37
- cls_features = image_features[:, 0:1, :]
38
- image_features = image_features[:, 1:, :] #drop cls token
39
  B, L, D = image_features.shape
40
  H_or_W = int(L**0.5)
41
  image_features = image_features.view(B, H_or_W, H_or_W, D)
@@ -45,4 +45,4 @@ class VisionEncoder(nn.Module):
45
  kernel_size=grid_stride,
46
  stride=grid_stride)
47
  image_features = image_features.permute(0, 2, 3, 1).view(B, -1, D)
48
- return torch.cat((cls_features, image_features), dim=1)
 
28
  return self.proj(image_features)
29
 
30
  def grid_pooling(self, image_features):
31
+ cls_features = image_features[:, 0:1, :]
32
+ image_features = image_features[:, 1:, :] #drop cls token
33
  if self.args.grid_size == -1: # no grid pooling
34
+ return torch.cat((image_features, cls_features), dim=1)
35
  if self.args.grid_size == 0: # take cls token
36
+ return cls_features
37
  if self.args.grid_size == 1: # global avg pooling
38
+ return torch.cat((image_features.mean(dim=1, keepdim=True), cls_features), dim=1)
 
 
39
  B, L, D = image_features.shape
40
  H_or_W = int(L**0.5)
41
  image_features = image_features.view(B, H_or_W, H_or_W, D)
 
45
  kernel_size=grid_stride,
46
  stride=grid_stride)
47
  image_features = image_features.permute(0, 2, 3, 1).view(B, -1, D)
48
+ return torch.cat((image_features, cls_features), dim=1)