gmastrapas
committed on
Commit
•
952897b
1
Parent(s):
b845577
feat: add autocasting in vision.patch_embed
Browse files- eva_model.py +2 -1
eva_model.py
CHANGED
@@ -462,13 +462,14 @@ class PatchEmbed(nn.Module):
|
|
462 |
)
|
463 |
|
464 |
def forward(self, x, **kwargs):
    """Split an image batch into patch embeddings.

    Applies the patch projection (``self.proj``, a strided convolution),
    flattens the spatial grid, and returns tokens of shape
    ``(B, num_patches, embed_dim)``.

    Args:
        x: Input image tensor of shape ``(B, C, H, W)``. ``H``/``W`` must
            match ``self.img_size`` exactly.
        **kwargs: Ignored; kept for interface compatibility with callers
            that pass extra arguments.

    Returns:
        Tensor of shape ``(B, num_patches, embed_dim)`` in the dtype of
        the projection weights.

    Raises:
        AssertionError: If the spatial size of ``x`` does not match
            ``self.img_size``.
    """
    # Cast the input to the projection weights' dtype so the conv works
    # under autocast / mixed-precision setups (no-op when dtypes match).
    target_dtype = self.proj.weight.dtype
    B, C, H, W = x.shape
    # FIXME look at relaxing size constraints
    assert H == self.img_size[0] and W == self.img_size[1], (
        f"Input image size ({H}*{W}) doesn't match model "
        f'({self.img_size[0]}*{self.img_size[1]}).'
    )
    x = self.proj(x.to(dtype=target_dtype)).flatten(2).transpose(1, 2)
    return x
|
473 |
|
474 |
|
|
|
462 |
)
|
463 |
|
464 |
def forward(self, x, **kwargs):
    """Project an image batch into a sequence of patch embeddings.

    The input is first cast to the dtype of the projection weights so the
    convolution runs cleanly under autocast / mixed precision, then the
    patch grid is flattened into ``(B, num_patches, embed_dim)`` tokens.

    Args:
        x: Image tensor of shape ``(B, C, H, W)``; spatial size must
            equal ``self.img_size``.
        **kwargs: Unused; accepted for signature compatibility.

    Returns:
        Patch-token tensor of shape ``(B, num_patches, embed_dim)``.
    """
    weight_dtype = self.proj.weight.dtype
    B, C, H, W = x.shape
    # FIXME look at relaxing size constraints
    assert H == self.img_size[0] and W == self.img_size[1], (
        f"Input image size ({H}*{W}) doesn't match model "
        f'({self.img_size[0]}*{self.img_size[1]}).'
    )
    patches = self.proj(x.to(dtype=weight_dtype))
    tokens = patches.flatten(2)
    return tokens.transpose(1, 2)
|
474 |
|
475 |
|