curt-park committed on
Commit 8760721
1 Parent(s): ce58c9d

Use weights in cache dir

Files changed (3)
  1. ViT-B-32.pt +0 -3
  2. app.py +13 -5
  3. sam_vit_h_4b8939.pth +0 -3
ViT-B-32.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af
-size 353976522
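The ViT-B-32.pt LFS pointer can be dropped because, in the app.py diff below, load_clip now passes the model name "ViT-B/32" to clip.load, which downloads and caches the weights itself (under ~/.cache/clip by default) instead of reading a checked-in .pt file. A minimal sketch of that call, assuming the openai CLIP package and torch are installed:

import clip
import torch

# Loading by model name lets the clip package fetch and cache the weights
# on first use, so the repository no longer needs to ship ViT-B-32.pt.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, preprocess = clip.load("ViT-B/32", device=device)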
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import urllib
 from functools import lru_cache
 from random import randint
 from typing import Any, Callable, Dict, List, Tuple
@@ -11,7 +12,10 @@ import PIL
 import torch
 from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
 
-CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
+
+CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
+CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
+CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
 MODEL_TYPE = "default"
 MAX_WIDTH = MAX_HEIGHT = 800
 THRESHOLD = 0.05
@@ -20,6 +24,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 @lru_cache
 def load_mask_generator() -> SamAutomaticMaskGenerator:
+    if not os.path.exists(CHECKPOINT_PATH):
+        os.makedirs(CHECKPOINT_PATH)
+    checkpoint = os.path.join(CHECKPOINT_PATH, CHECKPOINT_NAME)
+    if not os.path.exists(checkpoint):
+        urllib.request.urlretrieve(CHECKPOINT_URL, checkpoint)
     sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device)
     mask_generator = SamAutomaticMaskGenerator(sam)
     return mask_generator
@@ -27,10 +36,9 @@ def load_mask_generator() -> SamAutomaticMaskGenerator:
 
 @lru_cache
 def load_clip(
-    name: str = "ViT-B-32.pt",
+    name: str = "ViT-B/32",
 ) -> Tuple[torch.nn.Module, Callable[[PIL.Image.Image], torch.Tensor]]:
-    model_path = os.path.join(".", name)
-    model, preprocess = clip.load(model_path, device=device)
+    model, preprocess = clip.load(name, device=device)
     return model.to(device), preprocess
 
 
@@ -63,7 +71,7 @@ def get_scores(crops: List[PIL.Image.Image], query: str) -> torch.Tensor:
 def crop_image(image: np.ndarray, mask: Dict[str, Any]) -> PIL.Image.Image:
     x, y, w, h = mask["bbox"]
     masked = image * np.expand_dims(mask["segmentation"], -1)
-    crop = masked[y : y + h, x : x + w]
+    crop = masked[y: y + h, x: x + w]
     if h > w:
         top, bottom, left, right = 0, 0, (h - w) // 2, (h - w) // 2
     else:
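Taken together, the load_mask_generator change downloads the SAM checkpoint into ~/.cache/SAM on first use, which is why the LFS-tracked sam_vit_h_4b8939.pth below can also be deleted. A minimal sketch of that download-and-cache pattern, under two assumptions not shown in the diff: it imports urllib.request explicitly (the diff relies on a bare import urllib), and it passes the downloaded weight file rather than the cache directory to sam_model_registry:

import os
import urllib.request

import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
MODEL_TYPE = "default"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_mask_generator() -> SamAutomaticMaskGenerator:
    # Create the cache directory if it does not exist yet.
    os.makedirs(CHECKPOINT_PATH, exist_ok=True)
    # Download the SAM weights only on the first run; later runs reuse the cache.
    checkpoint = os.path.join(CHECKPOINT_PATH, CHECKPOINT_NAME)
    if not os.path.exists(checkpoint):
        urllib.request.urlretrieve(CHECKPOINT_URL, checkpoint)
    # Build the model from the cached weight file and wrap it in a mask generator.
    sam = sam_model_registry[MODEL_TYPE](checkpoint=checkpoint).to(device)
    return SamAutomaticMaskGenerator(sam)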
sam_vit_h_4b8939.pth DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
-size 2564550879