velier commited on
Commit
4918746
1 Parent(s): 7f908b9

add to mps device in _build_GOT_vision

Browse files
Files changed (1) hide show
  1. got_vision_b.py +4 -2
got_vision_b.py CHANGED
@@ -448,6 +448,8 @@ def _build_GOT_vision(
448
  image_size = 1024
449
  vit_patch_size = 16
450
  image_embedding_size = image_size // vit_patch_size
 
 
451
  image_encoder=ImageEncoderViT(
452
  depth=encoder_depth,
453
  embed_dim=encoder_embed_dim,
@@ -461,8 +463,8 @@ def _build_GOT_vision(
461
  global_attn_indexes=encoder_global_attn_indexes,
462
  window_size=14,
463
  out_chans=prompt_embed_dim,
464
- )
465
-
466
 
467
  return image_encoder
468
 
 
448
  image_size = 1024
449
  vit_patch_size = 16
450
  image_embedding_size = image_size // vit_patch_size
451
+ device = torch.device('mps')
452
+
453
  image_encoder=ImageEncoderViT(
454
  depth=encoder_depth,
455
  embed_dim=encoder_embed_dim,
 
463
  global_attn_indexes=encoder_global_attn_indexes,
464
  window_size=14,
465
  out_chans=prompt_embed_dim,
466
+ ).to(device)
467
+
468
 
469
  return image_encoder
470