jadechoghari commited on
Commit
1b15d03
1 Parent(s): d2d1eb4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +93 -0
README.md CHANGED
@@ -26,6 +26,99 @@ Our method leverages the pre-trained SAM model with only marginal parameter incr
26
 
27
  <img width="1096" alt="image" src="figures/architecture.jpg">
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  ## Comparison of computational requirements
31
  <img width="720" alt="image" src='figures/Computational requirements.PNG'>
 
26
 
27
  <img width="1096" alt="image" src="figures/architecture.jpg">
28
 
29
+ **Disclaimer**: Content from **this** model card has been written by the Hugging Face team, and parts of it were copy pasted from the original [SAM model card](https://github.com/facebookresearch/segment-anything).
30
+
31
+ # Model Details
32
+
33
+ The RobustSAM model is made up of 3 modules:
34
+ - The `VisionEncoder`: a VIT based image encoder. It computes the image embeddings using attention on patches of the image. Relative Positional Embedding is used.
35
+ - The `PromptEncoder`: generates embeddings for points and bounding boxes
36
+ - The `MaskDecoder`: a two-ways transformer which performs cross attention between the image embedding and the point embeddings (->) and between the point embeddings and the image embeddings. The outputs are fed
37
+ - The `Neck`: predicts the output masks based on the contextualized masks produced by the `MaskDecoder`.
38
+ # Usage
39
+
40
+
41
+ ## Prompted-Mask-Generation
42
+
43
+ ```python
44
+ from PIL import Image
45
+ import requests
46
+ from transformers import AutoProcessor, AutoModelForMaskGeneration
47
+
48
+ # load the RobustSAM model and processor
49
+ processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
50
+ model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base")
51
+
52
+ # load an image from a url
53
+ img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
54
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
55
+
56
+ # we define input points (2D localization of an object in the image)
57
+ input_points = [[[450, 600]]] # example point
58
+
59
+ ```
60
+
61
+
62
+ ```python
63
+ # process the image and input points
64
+ inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
65
+
66
+ # generate masks using the model
67
+ with torch.no_grad():
68
+ outputs = model(**inputs)
69
+ masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
70
+ scores = outputs.iou_scores
71
+
72
+ ```
73
+ Among other arguments to generate masks, you can pass 2D locations on the approximate position of your object of interest, a bounding box wrapping the object of interest (the format should be x, y coordinate of the top right and bottom left point of the bounding box), a segmentation mask. At this time of writing, passing a text as input is not supported by the official model according to [the official repository](https://github.com/facebookresearch/segment-anything/issues/4#issuecomment-1497626844).
74
+ For more details, refer to this notebook, which shows a walk throught of how to use the model, with a visual example!
75
+
76
+ ## Automatic-Mask-Generation
77
+
78
+ The model can be used for generating segmentation masks in a "zero-shot" fashion, given an input image. The model is automatically prompt with a grid of `1024` points
79
+ which are all fed to the model.
80
+
81
+ The pipeline is made for automatic mask generation. The following snippet demonstrates how easy you can run it (on any device! Simply feed the appropriate `points_per_batch` argument)
82
+ ```python
83
+ from transformers import pipeline
84
+
85
+ # initialize the pipeline for mask generation
86
+ generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-base", device=0, points_per_batch=256)
87
+
88
+ image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
89
+ outputs = generator(image_url, points_per_batch=256)
90
+ ```
91
+ Now to display the generated mask on the image:
92
+ ```python
93
+ import matplotlib.pyplot as plt
94
+ from PIL import Image
95
+ import numpy as np
96
+
97
+ # simple function to display the mask
98
+ def show_mask(mask, ax, random_color=False):
99
+ if random_color:
100
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
101
+ else:
102
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
103
+
104
+ # get the height and width from the mask
105
+ h, w = mask.shape[-2:]
106
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
107
+ ax.imshow(mask_image)
108
+
109
+ # display the original image
110
+ plt.imshow(np.array(raw_image))
111
+ ax = plt.gca()
112
+
113
+ # loop through the masks and display each one
114
+ for mask in outputs["masks"]:
115
+ show_mask(mask, ax=ax, random_color=True)
116
+
117
+ plt.axis("off")
118
+
119
+ # show the image with the masks
120
+ plt.show()
121
+ ```
122
 
123
  ## Comparison of computational requirements
124
  <img width="720" alt="image" src='figures/Computational requirements.PNG'>