Use weights in cache dir
- ViT-B-32.pt +0 -3
- app.py +13 -5
- sam_vit_h_4b8939.pth +0 -3
ViT-B-32.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af
-size 353976522
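Both deleted files in this commit (this one and sam_vit_h_4b8939.pth below) are Git LFS pointers rather than the weights themselves: three text lines naming the LFS spec, the SHA-256 oid of the real blob, and its byte size (353976522 bytes here, roughly 340 MB). A small hypothetical helper, sketched only for illustration, that recognizes such a pointer by its first line:

```python
def is_lfs_pointer(path: str) -> bool:
    """Heuristic check (hypothetical helper): Git LFS pointers are tiny
    ASCII files whose first line names the LFS spec, followed by
    `oid sha256:...` and `size ...` lines."""
    try:
        with open(path, encoding="ascii") as f:
            first_line = f.readline().strip()
    except (OSError, UnicodeDecodeError):
        return False  # binary or unreadable: not a pointer file
    return first_line == "version https://git-lfs.github.com/spec/v1"
```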
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import urllib
 from functools import lru_cache
 from random import randint
 from typing import Any, Callable, Dict, List, Tuple
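One subtlety in this hunk: it adds a bare `import urllib`, while the code below calls `urllib.request.urlretrieve`. In CPython, importing the `urllib` package does not itself import the `request` submodule, so the call only succeeds if some other dependency has already imported `urllib.request`. The explicit form is the reliable one — a minimal sketch (the local target path is illustrative):

```python
# Bare `import urllib` does not bind the `request` submodule; be explicit.
import urllib.request

urllib.request.urlretrieve(
    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "sam_vit_h_4b8939.pth",  # illustrative download target
)
```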
@@ -11,7 +12,10 @@ import PIL
 import torch
 from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
 
-
+
+CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
+CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
+CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
 MODEL_TYPE = "default"
 MAX_WIDTH = MAX_HEIGHT = 800
 THRESHOLD = 0.05
@@ -20,6 +24,11 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 @lru_cache
 def load_mask_generator() -> SamAutomaticMaskGenerator:
+    if not os.path.exists(CHECKPOINT_PATH):
+        os.makedirs(CHECKPOINT_PATH)
+    checkpoint = os.path.join(CHECKPOINT_PATH, CHECKPOINT_NAME)
+    if not os.path.exists(checkpoint):
+        urllib.request.urlretrieve(CHECKPOINT_URL, checkpoint)
     sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device)
     mask_generator = SamAutomaticMaskGenerator(sam)
     return mask_generator
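The added lines resolve the weights to ~/.cache/SAM/sam_vit_h_4b8939.pth, yet, as rendered on both sides of the diff, the unchanged context line still passes CHECKPOINT_PATH — now a directory — as `checkpoint=`, and segment_anything's build functions expect a checkpoint file they can torch.load. A standalone sketch of the apparent intent; the structure follows the diff, but using the resolved `checkpoint` file and `os.makedirs(..., exist_ok=True)` are my assumptions, not the Space's code as committed:

```python
import os
import urllib.request
from functools import lru_cache

import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

CHECKPOINT_PATH = os.path.join(os.path.expanduser("~"), ".cache", "SAM")
CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@lru_cache
def load_mask_generator() -> SamAutomaticMaskGenerator:
    # exist_ok=True creates the cache dir idempotently (assumption: intent).
    os.makedirs(CHECKPOINT_PATH, exist_ok=True)
    checkpoint = os.path.join(CHECKPOINT_PATH, CHECKPOINT_NAME)
    # Download the ~2.4 GiB ViT-H weights once; later calls hit the cache.
    if not os.path.exists(checkpoint):
        urllib.request.urlretrieve(CHECKPOINT_URL, checkpoint)
    # Assumption: pass the resolved file, not the directory, as the checkpoint.
    sam = sam_model_registry["default"](checkpoint=checkpoint).to(device)
    return SamAutomaticMaskGenerator(sam)
```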
@@ -27,10 +36,9 @@ def load_mask_generator() -> SamAutomaticMaskGenerator:
 
 @lru_cache
 def load_clip(
-    name: str = "ViT-B
+    name: str = "ViT-B/32",
 ) -> Tuple[torch.nn.Module, Callable[[PIL.Image.Image], torch.Tensor]]:
-
-    model, preprocess = clip.load(model_path, device=device)
+    model, preprocess = clip.load(name, device=device)
     return model.to(device), preprocess
 
 
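Loading CLIP by its published name rather than a local file is what makes the deleted ViT-B-32.pt pointer unnecessary: `clip.load()` accepts a model name and downloads and caches the weights itself (under ~/.cache/clip by default). A minimal usage sketch:

```python
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# clip.load() fetches "ViT-B/32" into its own cache on first use,
# so the repository no longer needs to ship the .pt file.
model, preprocess = clip.load("ViT-B/32", device=device)
```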
@@ -63,7 +71,7 @@ def get_scores(crops: List[PIL.Image.Image], query: str) -> torch.Tensor:
 def crop_image(image: np.ndarray, mask: Dict[str, Any]) -> PIL.Image.Image:
     x, y, w, h = mask["bbox"]
     masked = image * np.expand_dims(mask["segmentation"], -1)
-    crop = masked[y
+    crop = masked[y: y + h, x: x + w]
     if h > w:
         top, bottom, left, right = 0, 0, (h - w) // 2, (h - w) // 2
     else:
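The fixed slice reflects NumPy's (row, column) indexing: an image array is H x W x C, so an XYWH bbox like the one segment_anything returns crops as `[y : y + h, x : x + w]`. A quick check of the shape arithmetic:

```python
import numpy as np

image = np.zeros((600, 800, 3), dtype=np.uint8)  # H x W x C
x, y, w, h = 100, 50, 200, 120                   # bbox in (x, y, w, h)
crop = image[y : y + h, x : x + w]               # rows first (y), then columns (x)
assert crop.shape == (120, 200, 3)               # h x w x C, as intended
```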
sam_vit_h_4b8939.pth DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
-size 2564550879