mrtlive committed on
Commit 292c73a
1 Parent(s): ea0a437

server file and req

Files changed (3)
  1. README.md +29 -1
  2. app.py +164 -0
  3. requirements.txt +6 -0
README.md CHANGED
@@ -10,4 +10,32 @@ pinned: false
  license: apache-2.0
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Segment Anything Model from Facebook
+
+ This is an implementation of the Segment Anything model from Facebook using PyTorch. The model can be used for image segmentation tasks to separate foreground objects from the background.
+
+ ## How to Use the Model
+
+ We have implemented an API using FastAPI and Uvicorn to provide an easy-to-use interface for the Segment Anything model. The API allows users to send image files to the model and receive segmentation masks in response.
+
+ To use the model, follow these steps:
+
+ 1. Clone this repository to your local machine.
+ 2. Install the required packages by running `pip install -r requirements.txt`.
+ 3. Download the `sam_vit_b_01ec64.pth` checkpoint from the [segment-anything repository](https://github.com/facebookresearch/segment-anything), then start the API server by running `python app.py`.
+ 4. Send a POST request to `http://localhost:7860/image` with the image file attached as form data, then send click or box prompts to the `/click` and `/rect` endpoints.
+
+ The response from `/click` and `/rect` is a JSON object containing the segmented mask as a base64-encoded PNG string.
+
+ ## How the Model Works
+
+ The Segment Anything Model (SAM) is a promptable segmentation model. A Vision Transformer image encoder computes an embedding of the input image once, and a lightweight mask decoder combines that embedding with encoded prompts (clicked points or bounding boxes) to output a segmentation map, where each pixel is labeled as foreground or background.
+
+ The model was trained on a large dataset of annotated images (SA-1B, with over one billion masks) using a combination of focal and dice losses. During training, the weights of the network are adjusted to minimize the difference between the predicted segmentation map and the ground truth segmentation map.
+
+ ## References
+
+ For more information about the Segment Anything model and its implementation, please refer to the following resources:
+
+ - [Segment Anything paper from Meta AI Research](https://arxiv.org/abs/2304.02643)
+ - [Official PyTorch implementation of the Segment Anything model](https://github.com/facebookresearch/segment-anything)
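
A minimal client sketch for the steps above, assuming the server from app.py below is running locally on port 7860. The `requests` package and the file names `input.png` / `mask.png` are illustrative assumptions, not part of this commit:

```python
# client_sketch.py -- exercise the /image and /click endpoints of app.py
from base64 import b64decode
from io import BytesIO

import requests  # assumed installed; not listed in requirements.txt
from PIL import Image

BASE = "http://localhost:7860"

# 1. Upload the image; the server computes its embedding via SamPredictor.set_image.
with open("input.png", "rb") as f:
    r = requests.post(f"{BASE}/image", files={"image": ("input.png", f, "image/png")})
r.raise_for_status()

# 2. Send a foreground click at pixel (x=200, y=150); the server responds with the
#    union of the masks from all clicks so far.
r = requests.post(f"{BASE}/click", data={"x": 200, "y": 150})
r.raise_for_status()

# 3. Decode the base64-encoded PNG mask from the JSON response.
mask = Image.open(BytesIO(b64decode(r.json()["masks"])))
mask.save("mask.png")
```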
app.py ADDED
@@ -0,0 +1,164 @@
+ from fastapi import FastAPI, status, File, Form, UploadFile
+ from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
+ from starlette.responses import RedirectResponse
+ from fastapi.middleware.cors import CORSMiddleware
+
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
+ import numpy as np
+ from io import BytesIO
+ from PIL import Image
+ from base64 import b64encode, b64decode
+
+ def pil_image_to_base64(image):
+     buffered = BytesIO()
+     image.save(buffered, format="PNG")
+     img_str = b64encode(buffered.getvalue()).decode("utf-8")
+     return img_str
+
+ sam_checkpoint = "sam_vit_b_01ec64.pth"  # "sam_vit_l_0b3195.pth" or "sam_vit_h_4b8939.pth"
+ model_type = "vit_b"  # "vit_l" or "vit_h"
+ device = "cpu"  # "cuda" if torch.cuda.is_available() else "cpu"
+
+ print("Loading model")
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
+ print("Finished loading")
+ predictor = SamPredictor(sam)
+
+ app = FastAPI(debug=True)
+ origins = [
+     "http://localhost",
+     "http://localhost:8000",
+     "http://127.0.0.1",
+     "http://127.0.0.1:8000",
+     "http://localhost:5173",
+     "http://127.0.0.1:5173",
+ ]
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=origins,
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"]
+ )
+
+ # Prompt state, kept module-level and shared across all clients
+ input_point = []
+ input_label = []
+ masks = []
+ mask_input = [None]
+
+ @app.post("/image")
51
+ async def process_images(
52
+ image: UploadFile = File(...)
53
+ ):
54
+ global input_point, input_label, mask_input, masks
55
+ input_point = []
56
+ input_label = []
57
+ masks = []
58
+ # mask_input = [None]
59
+
60
+ # Read the image and mask data as bytes
61
+ image_data = await image.read()
62
+
63
+ image_data = BytesIO(image_data)
64
+ img = np.array(Image.open(image_data))
65
+ print("get image", img.shape)
66
+ # produce an image embedding by calling SamPredictor.set_image
67
+ predictor.set_image(img[:,:,:-1])
68
+ print("finish setting image")
69
+ # Return a JSON response
70
+ return JSONResponse(
71
+ content={
72
+ "message": "Images received successfully",
73
+ },
74
+ status_code=200,
75
+ )
76
+
+
+ @app.post("/undo")
+ async def undo_click():
+     global input_point, input_label, mask_input
+     # Guard against undo when no clicks have been recorded
+     if input_point:
+         input_point.pop()
+         input_label.pop()
+         masks.pop()
+         # mask_input.pop()
+
+     return JSONResponse(
+         content={
+             "message": "Undo successful",
+         },
+         status_code=200,
+     )
+
+ @app.post("/click")
+ async def click_images(
+     x: int = Form(...),  # horizontal
+     y: int = Form(...)   # vertical
+ ):
+     global input_point, input_label, mask_input
+     input_point.append([x, y])
+     input_label.append(1)  # 1 marks a foreground click
+     print("get click", x, y)
+     print("input_point", input_point)
+     print("input_label", input_label)
+
+     masks_, scores_, logits_ = predictor.predict(
+         point_coords=np.array([input_point[-1]]),
+         point_labels=np.array([input_label[-1]]),
+         # mask_input=mask_input[-1],
+         multimask_output=True,  # SAM outputs 3 masks, we choose the one with highest score
+     )
+
+     # mask_input.append(logits_[np.argmax(scores_), :, :][None, :, :])
+     masks.append(masks_[np.argmax(scores_), :, :])
+     # Union of the masks from all clicks so far
+     res = np.zeros(masks[0].shape, dtype=bool)
+     for mask in masks:
+         res = np.logical_or(res, mask)
+     # Convert the boolean mask to a 0/255 grayscale image for PNG encoding
+     res = Image.fromarray((res * 255).astype(np.uint8))
+     # res.save("res.png")
+
+     # Return a JSON response
+     return JSONResponse(
+         content={
+             "masks": pil_image_to_base64(res),
+             "message": "Images processed successfully"
+         },
+         status_code=200,
+     )
+
+ @app.post("/rect")
+ async def rect_images(
+     start_x: int = Form(...),  # horizontal
+     start_y: int = Form(...),  # vertical
+     end_x: int = Form(...),    # horizontal
+     end_y: int = Form(...)     # vertical
+ ):
+     # Predict a single mask from the box prompt (XYXY pixel coordinates)
+     masks_, _, _ = predictor.predict(
+         point_coords=None,
+         point_labels=None,
+         box=np.array([[start_x, start_y, end_x, end_y]]),
+         multimask_output=False
+     )
+
+     # Convert the boolean mask to a 0/255 grayscale image for PNG encoding
+     res = Image.fromarray((masks_[0] * 255).astype(np.uint8))
+     # res.save("res.png")
+
+     # Return a JSON response
+     return JSONResponse(
+         content={
+             "masks": pil_image_to_base64(res),
+             "message": "Images processed successfully"
+         },
+         status_code=200,
+     )
+
+ @app.get('/')
+ def home():
+     return 'This is an API for the Segment Anything Model from Facebook. You can use it to segment anything.'
+
+
+ import uvicorn
+
+ if __name__ == '__main__':
+     uvicorn.run(app, host="0.0.0.0", port=7860)
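
Since `/click` and `/rect` return the mask as a base64-encoded PNG, a client can decode it and overlay it on the original image. A short sketch using only PIL and NumPy; the function name and file handling are illustrative, not part of this commit:

```python
# overlay_sketch.py -- tint the masked region of an image with a solid color
from base64 import b64decode
from io import BytesIO

import numpy as np
from PIL import Image


def overlay_mask(image_path: str, mask_b64: str, color=(255, 0, 0), alpha=0.5):
    img = np.array(Image.open(image_path).convert("RGB")).astype(np.float32)
    # The API encodes the mask as a 0/255 grayscale PNG; recover a boolean array
    mask = np.array(Image.open(BytesIO(b64decode(mask_b64)))) > 0
    # Blend the tint color into the masked pixels only
    img[mask] = (1 - alpha) * img[mask] + alpha * np.array(color, dtype=np.float32)
    return Image.fromarray(img.astype(np.uint8))
```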
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ fastapi
+ numpy
+ uvicorn
+ torch
+ torchvision
+ git+https://github.com/facebookresearch/segment-anything.git
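
Note that `pillow` (used by app.py) is typically pulled in transitively by torchvision, but `python-multipart`, which FastAPI needs to parse the `Form` and `File` fields used here, is not listed, so `pip install python-multipart` may be needed alongside these requirements. For reference, a minimal sketch of the SamPredictor calls that app.py builds on, runnable outside the server; the checkpoint path, image file, and prompt coordinates are illustrative assumptions:

```python
# predictor_sketch.py -- the SamPredictor workflow used by app.py, without FastAPI
import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

# Assumes sam_vit_b_01ec64.pth has been downloaded and "input.png" exists
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth").to("cpu")
predictor = SamPredictor(sam)

img = np.array(Image.open("input.png").convert("RGB"))
predictor.set_image(img)  # computes the image embedding once

# Point prompt: one foreground click; keep the highest-scoring of the 3 masks
masks, scores, logits = predictor.predict(
    point_coords=np.array([[200, 150]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
best = masks[np.argmax(scores)]

# Box prompt (XYXY pixel coordinates): a single mask
box_masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=np.array([[50, 50, 400, 300]]),
    multimask_output=False,
)
print(best.shape, box_masks[0].shape)
```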