add to mps device in _build_GOT_vision
Browse files- got_vision_b.py +4 -2
got_vision_b.py
CHANGED
@@ -448,6 +448,8 @@ def _build_GOT_vision(
|
|
448 |
image_size = 1024
|
449 |
vit_patch_size = 16
|
450 |
image_embedding_size = image_size // vit_patch_size
|
|
|
|
|
451 |
image_encoder=ImageEncoderViT(
|
452 |
depth=encoder_depth,
|
453 |
embed_dim=encoder_embed_dim,
|
@@ -461,8 +463,8 @@ def _build_GOT_vision(
|
|
461 |
global_attn_indexes=encoder_global_attn_indexes,
|
462 |
window_size=14,
|
463 |
out_chans=prompt_embed_dim,
|
464 |
-
)
|
465 |
-
|
466 |
|
467 |
return image_encoder
|
468 |
|
|
|
448 |
image_size = 1024
|
449 |
vit_patch_size = 16
|
450 |
image_embedding_size = image_size // vit_patch_size
|
451 |
+
device = torch.device('mps')
|
452 |
+
|
453 |
image_encoder=ImageEncoderViT(
|
454 |
depth=encoder_depth,
|
455 |
embed_dim=encoder_embed_dim,
|
|
|
463 |
global_attn_indexes=encoder_global_attn_indexes,
|
464 |
window_size=14,
|
465 |
out_chans=prompt_embed_dim,
|
466 |
+
).to(device)
|
467 |
+
|
468 |
|
469 |
return image_encoder
|
470 |
|