Spaces:
Runtime error
Runtime error
server file and req
Browse files- README.md +29 -1
- app.py +164 -0
- requirements.txt +6 -0
README.md
CHANGED
@@ -10,4 +10,32 @@ pinned: false
|
|
10 |
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
+
# Segment Anything Model from Facebook
|
14 |
+
|
15 |
+
This is an implementation of the Segment Anything model from Facebook using PyTorch. The model can be used for image segmentation tasks to separate foreground objects from the background.
|
16 |
+
|
17 |
+
## How to Use the Model
|
18 |
+
|
19 |
+
We have implemented an API using FastAPI and Uvicorn to provide an easy-to-use interface for the Segment Anything model. The API allows users to send image files to the model and receive the segmented images in response.
|
20 |
+
|
21 |
+
To use the model, follow these steps:
|
22 |
+
|
23 |
+
1. Clone this repository to your local machine.
|
24 |
+
2. Install the required packages by running `pip install -r requirements.txt`.
|
25 |
+
3. Start the API server by running `python app.py`.
|
26 |
+
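4. Send a POST request to `http://localhost:7860/image` with the image file attached as form data (field name `image`), then send the prompt as a POST request to `/click` (form fields `x`, `y`) or `/rect` (form fields `start_x`, `start_y`, `end_x`, `end_y`).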
|
27 |
+
|
28 |
+
The response from `/click` and `/rect` is a JSON object whose `masks` field contains the segmentation mask as a base64-encoded PNG string.
|
29 |
+
|
30 |
+
## How the Model Works
|
31 |
+
|
32 |
+
The Segment Anything model combines a Vision Transformer (ViT) image encoder, a prompt encoder, and a lightweight mask decoder. The image is embedded once; each prompt (a clicked point or a bounding box) is then decoded into a segmentation map, where each pixel is assigned a label indicating whether it belongs to the selected object or to the background.
|
33 |
+
|
34 |
+
The model was trained on SA-1B, a large dataset of images with mask annotations, using a combination of focal loss and dice loss. During training, the weights of the network are adjusted to minimize the difference between the predicted masks and the ground-truth masks.
|
35 |
+
|
36 |
+
## References
|
37 |
+
|
38 |
+
For more information about the Segment Anything model and its implementation, please refer to the following resources:
|
39 |
+
|
40 |
+
- [Segment Anything research paper (Meta AI / Facebook Research)](https://arxiv.org/abs/2304.02643)
|
41 |
+
- [Official PyTorch implementation of the Segment Anything Model](https://github.com/facebookresearch/segment-anything)
|
app.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, status, File, Form, UploadFile
|
2 |
+
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
|
3 |
+
from starlette.responses import RedirectResponse
|
4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
5 |
+
|
6 |
+
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
|
7 |
+
import numpy as np
|
8 |
+
from io import BytesIO
|
9 |
+
from PIL import Image
|
10 |
+
from base64 import b64encode, b64decode
|
11 |
+
|
12 |
+
def pil_image_to_base64(image, image_format="PNG"):
    """Encode an image in memory and return the bytes as base64 text.

    Args:
        image: Object exposing ``save(fp, format=...)``; in this file it is
            always a ``PIL.Image.Image``.
        image_format: Encoder name forwarded to ``image.save``. Defaults to
            ``"PNG"``, which is what this helper always used before the
            parameter existed, so existing callers are unaffected.

    Returns:
        The encoded image as a base64 ``str``.
    """
    buffer = BytesIO()
    image.save(buffer, format=image_format)
    return b64encode(buffer.getvalue()).decode("utf-8")
|
17 |
+
|
18 |
+
# Checkpoint file and the registry key of the architecture it belongs to.
# Alternatives: "sam_vit_l_0b3195.pth" / "vit_l" or "sam_vit_h_4b8939.pth" / "vit_h".
# NOTE(review): the checkpoint is loaded from a relative path and is not one
# of the files added in this commit - confirm it is present at start-up.
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"
# Inference device; use "cuda" if torch.cuda.is_available() else "cpu".
device = "cpu"

# The model is built once at import time and shared by every request.
print("Loading model")
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam = sam.to(device)
print("Finishing loading")
predictor = SamPredictor(sam)
|
26 |
+
|
27 |
+
app = FastAPI(debug=True)

# Browser origins that are allowed to call this API (CORS allow-list).
origins = [
    "http://localhost",
    "http://localhost:8000",
    "http://127.0.0.1",
    "http://127.0.0.1:8000",
    "http://localhost:5173",
    "http://127.0.0.1:5173",
]

# Any method and any header is accepted from the origins listed above.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
44 |
+
|
45 |
+
# Module-level prompt state, shared by every request handled by this process.
# input_point: [x, y] pairs appended by /click
# input_label: one label per click (/click always appends 1)
# masks:       the mask selected for each click
input_point, input_label, masks = [], [], []
# Only referenced from commented-out code in the handlers below.
mask_input = [None]
|
49 |
+
|
50 |
+
@app.post("/image")
async def process_images(
    image: UploadFile = File(...)
):
    """Receive an image, reset the prompt state and compute its SAM embedding.

    Args:
        image: Uploaded image file (multipart form field ``image``).

    Returns:
        JSON ``{"message": ...}`` with status 200 once the embedding has been
        computed, or status 400 when the upload cannot be decoded as an image.
    """
    global input_point, input_label, mask_input, masks

    # Read the uploaded image data as bytes
    image_data = BytesIO(await image.read())
    try:
        # Always hand SAM a 3-channel RGB array. The previous ``img[:,:,:-1]``
        # only worked for RGBA uploads: it removed the blue channel from RGB
        # images and raised on 2-D grayscale arrays.
        img = np.array(Image.open(image_data).convert("RGB"))
    except OSError:
        # PIL raises OSError subclasses for unreadable or truncated images.
        # The previous image and click history are left untouched.
        return JSONResponse(
            content={
                "message": "Uploaded file is not a readable image",
            },
            status_code=400,
        )
    print("get image", img.shape)

    # A new image invalidates every click and mask recorded so far.
    input_point = []
    input_label = []
    masks = []
    # mask_input = [None]

    # produce an image embedding by calling SamPredictor.set_image
    predictor.set_image(img)
    print("finish setting image")
    # Return a JSON response
    return JSONResponse(
        content={
            "message": "Images received successfully",
        },
        status_code=200,
    )
|
76 |
+
|
77 |
+
|
78 |
+
# NOTE(review): this handler reuses the name of the /image handler, so the
# module-level name ``process_images`` ends up bound to this function. Both
# routes are still registered; the name is kept to avoid an interface change.
@app.post("/undo")
async def process_images():
    """Discard the most recent click together with its label and mask.

    Returns:
        JSON ``{"message": ...}`` with status 200 after removing the last
        entry, or status 400 when there is nothing to undo (previously an
        IndexError from ``pop()`` on an empty list surfaced as HTTP 500).
    """
    global input_point, input_label, mask_input, masks
    if not (input_point and input_label and masks):
        return JSONResponse(
            content={
                "message": "Nothing to undo",
            },
            status_code=400,
        )

    input_point.pop()
    input_label.pop()
    masks.pop()
    # mask_input.pop()

    return JSONResponse(
        content={
            "message": "Clear successfully",
        },
        status_code=200,
    )
|
92 |
+
|
93 |
+
@app.post("/click")
async def click_images(
    x: int = Form(...),  # horizontal
    y: int = Form(...)  # vertical
):
    """Segment the object under a clicked point and return the merged mask.

    Each click is predicted on its own; the highest-scoring of SAM's three
    candidate masks is stored, and the response contains the union of the
    masks stored for all clicks since the last ``/image`` call.

    Args:
        x: Horizontal pixel coordinate of the click.
        y: Vertical pixel coordinate of the click.

    Returns:
        JSON with ``masks`` (base64-encoded PNG of the merged mask) and
        ``message`` with status 200, or status 400 when no image has been
        uploaded yet.
    """
    global input_point, input_label, mask_input, masks
    if not predictor.is_image_set:
        return JSONResponse(
            content={
                "message": "No image set; POST an image to /image first",
            },
            status_code=400,
        )

    point = [x, y]
    label = 1
    print("get click", x, y)

    masks_, scores_, logits_ = predictor.predict(
        point_coords=np.array([point]),
        point_labels=np.array([label]),
        # mask_input=mask_input[-1],
        multimask_output=True,  # SAM outputs 3 masks, we choose the one with highest score
    )

    # Record the click only after prediction succeeded, so the three state
    # lists always have the same length and /undo removes matching entries.
    input_point.append(point)
    input_label.append(label)
    # mask_input.append(logits[np.argmax(scores), :, :][None, :, :])
    masks.append(masks_[np.argmax(scores_), :, :])
    print("input_point", input_point)
    print("input_label", input_label)

    # Union of all stored masks (element-wise OR across the list).
    res = Image.fromarray(np.logical_or.reduce(masks))
    # res.save("res.png")

    # Return a JSON response
    return JSONResponse(
        content={
            "masks": pil_image_to_base64(res),
            "message": "Images processed successfully"
        },
        status_code=200,
    )
|
129 |
+
|
130 |
+
@app.post("/rect")
async def rect_images(
    start_x: int = Form(...),  # horizontal
    start_y: int = Form(...),  # vertical
    end_x: int = Form(...),  # horizontal
    end_y: int = Form(...)  # vertical
):
    """Segment the object inside a rectangle and return its mask.

    Args:
        start_x: Horizontal coordinate of the first rectangle corner.
        start_y: Vertical coordinate of the first rectangle corner.
        end_x: Horizontal coordinate of the opposite corner.
        end_y: Vertical coordinate of the opposite corner.

    Returns:
        JSON with ``masks`` (base64-encoded PNG of the mask) and ``message``
        with status 200, or status 400 when no image has been uploaded yet.
    """
    if not predictor.is_image_set:
        return JSONResponse(
            content={
                "message": "No image set; POST an image to /image first",
            },
            status_code=400,
        )

    # SAM expects the box as (x_min, y_min, x_max, y_max); sort the corners so
    # a rectangle dragged right-to-left or bottom-to-top is handled as well.
    left, right = sorted((start_x, end_x))
    top, bottom = sorted((start_y, end_y))

    masks_, _, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=np.array([[left, top, right, bottom]]),
        multimask_output=False
    )

    res = Image.fromarray(masks_[0])
    # res.save("res.png")

    # Return a JSON response
    return JSONResponse(
        content={
            "masks": pil_image_to_base64(res),
            "message": "Images processed successfully"
        },
        status_code=200,
    )
|
155 |
+
|
156 |
+
@app.get('/')
def home():
    """Return a one-sentence description of this API."""
    description = 'This is API for uses Segment-Anything Model from facebook. You can use it to segment anything.'
    return description
|
159 |
+
|
160 |
+
|
161 |
+
import uvicorn
|
162 |
+
|
163 |
+
if __name__ == "__main__":
    # Serve the app on every network interface, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fastapi
numpy
uvicorn
torch
torchvision
# FastAPI needs python-multipart to parse Form(...) and File(...) parameters.
python-multipart
# app.py imports PIL directly.
Pillow
git+https://github.com/facebookresearch/segment-anything.git
|