<img width="1096" alt="image" src="figures/architecture.jpg">
|
28 |
|
29 |
+
**Disclaimer**: Content from **this** model card has been written by the Hugging Face team, and parts of it were copy pasted from the original [SAM model card](https://github.com/facebookresearch/segment-anything).
|
30 |
+
|
31 |
+
# Model Details
|
32 |
+
|
33 |
+
The RobustSAM model is made up of 3 modules:
- The `VisionEncoder`: a ViT-based image encoder. It computes the image embeddings by applying attention over patches of the image, using relative positional embeddings.
- The `PromptEncoder`: generates embeddings for points and bounding boxes.
- The `MaskDecoder`: a two-way transformer that performs cross-attention between the image embedding and the point embeddings, and between the point embeddings and the image embeddings. Its outputs are fed to the `Neck`.
- The `Neck`: predicts the output masks based on the contextualized masks produced by the `MaskDecoder`.
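
A quick way to see these components is to load the checkpoint and print its top-level submodules. This is a minimal sketch, assuming the checkpoint loads through `AutoModelForMaskGeneration` as in the usage examples below; the exact submodule names are implementation details and may vary between releases.

```python
from transformers import AutoModelForMaskGeneration

# load the checkpoint and list its top-level submodules
# (submodule names may differ slightly between transformers releases)
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base")
for name, module in model.named_children():
    print(name, "->", type(module).__name__)
```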
# Usage
## Prompted-Mask-Generation
```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForMaskGeneration

# load the RobustSAM model and processor
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base").to("cuda")

# load an image from a url
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

# define input points (a 2D localization of an object in the image)
input_points = [[[450, 600]]]  # example point
```

```python
# process the image and the input points
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")

# generate masks using the model
with torch.no_grad():
    outputs = model(**inputs)

# post-process the masks to the original image size and collect the IoU scores
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
```
Among other arguments for generating masks, you can pass 2D locations indicating the approximate position of your object of interest, a bounding box wrapping the object of interest (formatted as the x, y coordinates of the top-left corner followed by the x, y coordinates of the bottom-right corner), or a segmentation mask. At the time of writing, passing text as input is not supported by the official model, according to [the official repository](https://github.com/facebookresearch/segment-anything/issues/4#issuecomment-1497626844).
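
As a sketch of the bounding-box option, the same generation code works with the processor's `input_boxes` argument; the coordinates below are made-up example values for the car image above.

```python
# prompt the model with a bounding box instead of a point
# format: [x_top_left, y_top_left, x_bottom_right, y_bottom_right] (example values)
input_boxes = [[[75, 275, 1725, 850]]]

inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
```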
For more details, refer to this notebook, which shows a walkthrough of how to use the model, with a visual example!
## Automatic-Mask-Generation
The model can be used to generate segmentation masks in a "zero-shot" fashion, given an input image. The model is automatically prompted with a grid of `1024` points, all of which are fed to the model.

The pipeline handles automatic mask generation. The following snippet demonstrates how easily you can run it on any device: simply pass the appropriate `points_per_batch` argument.
```python
from transformers import pipeline

# initialize the pipeline for mask generation
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-base", device=0, points_per_batch=256)

image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
```
Now to display the generated masks on the image:
```python
import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image

# simple function to display a mask on a matplotlib axis
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    # get the height and width from the mask
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# load the same image the pipeline was run on
raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

# display the original image
plt.imshow(np.array(raw_image))
ax = plt.gca()

# loop through the masks and display each one
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)

plt.axis("off")

# show the image with the masks
plt.show()
```
## Comparison of computational requirements
<img width="720" alt="image" src='figures/Computational requirements.PNG'>