Add all mask drawing

- .gitignore +2 -0
- Makefile +8 -0
- app.py +87 -0
- examples/city.jpg +0 -0
- examples/dog.jpg +0 -0
- examples/food.jpg +0 -0
- examples/horse.jpg +0 -0
- requirements.txt +6 -0
- sam_vit_h_4b8939.pth +3 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+__pycache__
+flagged
Makefile
ADDED
@@ -0,0 +1,8 @@
+env:
+	conda create -n segment-anything python=3.9
+
+setup:
+	pip install -r requirements.txt
+
+run:
+	gradio app.py
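A note on usage (a sketch of a typical first run, assuming conda is already on PATH; "make env" only creates the environment, it does not activate it):

    make env
    conda activate segment-anything
    make setup
    make run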
app.py
ADDED
@@ -0,0 +1,87 @@
+import os
+import PIL.Image
+from functools import lru_cache
+
+from random import randint
+import gradio as gr
+import cv2
+import torch
+import numpy as np
+from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
+from typing import List
+
+CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
+MODEL_TYPE = "default"
+MAX_WIDTH = MAX_HEIGHT = 800
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+@lru_cache
+def load_mask_generator() -> SamAutomaticMaskGenerator:
+    sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device)
+    mask_generator = SamAutomaticMaskGenerator(sam)
+    return mask_generator
+
+
+def adjust_image_size(image: np.ndarray) -> np.ndarray:
+    # shrink the longer side to the MAX_WIDTH/MAX_HEIGHT limit, keeping aspect ratio
+    height, width = image.shape[:2]
+    if height > width:
+        if height > MAX_HEIGHT:
+            height, width = MAX_HEIGHT, int(MAX_HEIGHT / height * width)
+    else:
+        if width > MAX_WIDTH:
+            height, width = int(MAX_WIDTH / width * height), MAX_WIDTH
+    image = cv2.resize(image, (width, height))
+    return image
+
+
+def draw_masks(
+    image: np.ndarray, masks: List[dict], alpha: float = 0.7
+) -> np.ndarray:
+    for mask in masks:
+        color = [randint(127, 255) for _ in range(3)]
+        segmentation = mask["segmentation"]
+
+        # draw mask overlay: fill the masked pixels with a random bright color
+        colored_seg = np.expand_dims(segmentation, 0).repeat(3, axis=0)
+        colored_seg = np.moveaxis(colored_seg, 0, -1)
+        masked = np.ma.MaskedArray(image, mask=colored_seg, fill_value=color)
+        image_overlay = masked.filled()
+        image = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
+
+        # draw contour
+        contours, _ = cv2.findContours(
+            np.uint8(segmentation), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+        )
+        cv2.drawContours(image, contours, -1, (255, 0, 0), 2)
+    return image
+
+
+def segment(image_path: str, query: str) -> PIL.Image.Image:
+    mask_generator = load_mask_generator()
+    # reduce the size to save gpu memory
+    image = adjust_image_size(cv2.imread(image_path))
+    masks = mask_generator.generate(image)
+    image = draw_masks(image, masks)
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
+    return image
+
+
+demo = gr.Interface(
+    fn=segment,
+    inputs=[gr.Image(type="filepath"), "text"],
+    outputs="image",
+    allow_flagging="never",
+    title="Segment Anything with CLIP",
+    examples=[
+        [os.path.join(os.path.dirname(__file__), "examples/dog.jpg"), ""],
+        [os.path.join(os.path.dirname(__file__), "examples/city.jpg"), ""],
+        [os.path.join(os.path.dirname(__file__), "examples/food.jpg"), ""],
+        [os.path.join(os.path.dirname(__file__), "examples/horse.jpg"), ""],
+    ],
+)
+
+if __name__ == "__main__":
+    demo.launch()
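One observation on this revision: the text input (query) is wired into the interface and passed to segment, but never used; given the Space's title and the CLIP dependency in requirements.txt, it is presumably reserved for filtering masks by a text prompt in a later commit. Below is a minimal sketch of that idea, not part of this commit — the helper name filter_masks_by_query, the ViT-B/32 model choice, and the 0.2 threshold are all illustrative assumptions:

    import clip  # installed from git+https://github.com/openai/CLIP.git
    import cv2
    import numpy as np
    import PIL.Image
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clip_model, preprocess = clip.load("ViT-B/32", device=device)

    def filter_masks_by_query(
        image: np.ndarray, masks: list, query: str, threshold: float = 0.2
    ) -> list:
        # hypothetical helper: score each mask's bounding-box crop against the
        # text query with CLIP and keep only the masks that match well enough
        text = clip.tokenize([query]).to(device)
        kept = []
        for mask in masks:
            x, y, w, h = mask["bbox"]  # SAM emits an XYWH box for every mask
            # cv2 loads BGR, CLIP's preprocess expects an RGB PIL image
            crop = PIL.Image.fromarray(
                cv2.cvtColor(image[y : y + h, x : x + w], cv2.COLOR_BGR2RGB)
            )
            with torch.no_grad():
                image_features = clip_model.encode_image(
                    preprocess(crop).unsqueeze(0).to(device)
                )
                text_features = clip_model.encode_text(text)
            # cosine similarity between the normalized embeddings
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            if (image_features @ text_features.T).item() > threshold:
                kept.append(mask)
        return kept

In segment, such a helper would slot in between mask_generator.generate(image) and draw_masks whenever the query is non-empty, e.g. masks = filter_masks_by_query(image, masks, query) if query else masks.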
examples/city.jpg
ADDED
examples/dog.jpg
ADDED
examples/food.jpg
ADDED
examples/horse.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+gradio==3.24.1
+opencv-python==4.7.0.72
+pycocotools==2.0.6
+matplotlib==3.7.1
+git+https://github.com/facebookresearch/segment-anything.git
+git+https://github.com/openai/CLIP.git
sam_vit_h_4b8939.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
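Note: this entry is a Git LFS pointer rather than the checkpoint itself; the actual ViT-H SAM weights (about 2.4 GB, per the recorded size) are fetched by git-lfs when the repository is cloned, or with "git lfs pull" on an existing checkout.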