---
library_name: transformers
---

# RobustSAM: Segment Anything Robustly on Degraded Images (CVPR 2024 Highlight)
# Model Card for the ViT-Large (ViT-L) version

<a href="https://colab.research.google.com/drive/1mrOjUNFrfZ2vuTnWrfl9ebAQov3a9S6E?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
[![Huggingfaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/robustsam/robustsam/tree/main)

Official repository for RobustSAM: Segment Anything Robustly on Degraded Images

[Project Page](https://robustsam.github.io/) | [Paper](https://openaccess.thecvf.com/content/CVPR2024/html/Chen_RobustSAM_Segment_Anything_Robustly_on_Degraded_Images_CVPR_2024_paper.html) | [Video](https://www.youtube.com/watch?v=Awukqkbs6zM) | [Dataset](https://huggingface.co/robustsam/robustsam/tree/main/dataset)


## Updates
- July 2024: ✨ Training code, data and model checkpoints for different ViT backbones are released!
- June 2024: ✨ Inference code has been released!
- Feb 2024: ✨ RobustSAM was accepted into CVPR 2024!


## Introduction
Segment Anything Model (SAM) has emerged as a transformative approach in image segmentation, acclaimed for its robust zero-shot segmentation capabilities and flexible prompting system. Nonetheless, its performance is challenged by images with degraded quality. Addressing this limitation, we propose the Robust Segment Anything Model (RobustSAM), which enhances SAM's performance on low-quality images while preserving its promptability and zero-shot generalization.

Our method leverages the pre-trained SAM model with only a marginal increase in parameters and computation. The additional parameters of RobustSAM can be optimized within 30 hours on eight GPUs, demonstrating its feasibility and practicality for typical research laboratories. We also introduce the Robust-Seg dataset, a collection of 688K image-mask pairs with different degradations designed to train and evaluate our model optimally. Extensive experiments across various segmentation tasks and datasets confirm RobustSAM's superior performance, especially under zero-shot conditions, underscoring its potential for extensive real-world application. Additionally, our method has been shown to effectively improve the performance of SAM-based downstream tasks such as single image dehazing and deblurring.

<img width="1096" alt="image" src="figures/architecture.jpg">

**Disclaimer**: The content of **this** model card was written by the Hugging Face team, and parts of it were copy-pasted from the original [SAM model card](https://github.com/facebookresearch/segment-anything).

# Model Details

The RobustSAM model is made up of four modules:
  - The `VisionEncoder`: a ViT-based image encoder. It computes the image embeddings using attention over patches of the image. Relative positional embeddings are used.
  - The `PromptEncoder`: generates embeddings for points and bounding boxes.
  - The `MaskDecoder`: a two-way transformer that performs cross-attention between the image embedding and the point embeddings, and between the point embeddings and the image embedding. The outputs are fed to the `Neck`.
  - The `Neck`: predicts the output masks based on the contextualized masks produced by the `MaskDecoder`.
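
As a quick sanity check of this structure, the minimal sketch below lists the top-level submodules of the loaded checkpoint. The exact attribute names depend on how the checkpoint maps onto the `transformers` SAM implementation, so treat the printed names as whatever the library exposes rather than a one-to-one match with the list above.

```python
from transformers import AutoModelForMaskGeneration

# load the checkpoint and print its top-level submodules;
# this only inspects the model, it does not run inference
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base")
for name, module in model.named_children():
    print(f"{name}: {module.__class__.__name__}")
```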
# Usage


## Prompted-Mask-Generation

```python
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModelForMaskGeneration

# load the RobustSAM model and processor
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base")

# load an image from a url
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

# we define input points (2D localization of an object in the image)
input_points = [[[450, 600]]]  # example point

```


```python
import torch

# move the model to the GPU for faster inference (assumes a CUDA device is available)
model = model.to("cuda")

# process the image and input points
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")

# generate masks using the model
with torch.no_grad():
    outputs = model(**inputs)

# post-process the low-resolution masks back to the original image size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores
```
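The model returns several candidate masks per prompt together with predicted IoU scores. As a minimal follow-up sketch (assuming the usual SAM output shapes in `transformers`: one image, one prompted point, three candidate masks), you can keep only the highest-scoring mask:

```python
# scores has shape (batch, num_prompts, num_masks); pick the best candidate for the first prompt
best_idx = scores[0, 0].argmax().item()

# masks is a list with one entry per image, each of shape (num_prompts, num_masks, H, W)
best_mask = masks[0][0, best_idx]
print("best mask shape:", tuple(best_mask.shape), "score:", scores[0, 0, best_idx].item())
```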
Among other arguments used to generate masks, you can pass 2D locations giving the approximate position of your object of interest, a bounding box wrapping the object of interest (in the format x_min, y_min, x_max, y_max, i.e. the top-left and bottom-right corners of the box), or a segmentation mask. At the time of writing, passing text as input is not supported by the official model, according to [the official repository](https://github.com/facebookresearch/segment-anything/issues/4#issuecomment-1497626844).
For more details, refer to the Colab notebook linked above, which walks through how to use the model with a visual example!
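
For illustration, the sketch below swaps the point prompt for a bounding box. The coordinates are arbitrary example values, and the snippet assumes the standard `input_boxes` argument of the SAM processor in `transformers`.

```python
import torch

# bounding box prompt in [x_min, y_min, x_max, y_max] format (top-left and bottom-right corners);
# the values here are illustrative, not tuned for the example image
input_boxes = [[[150, 300, 1800, 1100]]]

inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)

masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
)
```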

## Automatic-Mask-Generation

The model can be used to generate segmentation masks in a "zero-shot" fashion, given an input image. The model is automatically prompted with a grid of `1024` points, all of which are fed to the model.

The `mask-generation` pipeline handles this automatically. The following snippet demonstrates how easily you can run it (on any device; simply pass the appropriate `points_per_batch` argument):
```python
from PIL import Image
import requests
from transformers import pipeline

# initialize the pipeline for mask generation
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-base", device=0, points_per_batch=256)

# load the image once so it can be reused for visualization below
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

outputs = generator(raw_image, points_per_batch=256)
```
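The pipeline returns a dictionary of binary masks and their predicted IoU scores; a quick way to inspect it (the key names follow the standard `mask-generation` pipeline output):

```python
# number of generated masks and the first few predicted IoU scores
print(len(outputs["masks"]))
print(outputs["scores"][:5])
```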
To display the generated masks on the image:
```python
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

# simple function to display the mask
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    
    # get the height and width from the mask
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

# display the original image
plt.imshow(np.array(raw_image))
ax = plt.gca()

# loop through the masks and display each one
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)

plt.axis("off")

# show the image with the masks
plt.show()
```

## Comparison of computational requirements
<img width="720" alt="image" src='figures/Computational requirements.PNG'>

## Visual Comparison
<table>
  <tr>
    <td>
      <img src="figures/gif_output/blur_back_n_forth.gif" width="380">
    </td>
    <td>
      <img src="figures/gif_output/haze_back_n_forth.gif" width="380">
    </td>
  </tr>
  <tr>
    <td>
      <img src="figures/gif_output/lowlight_back_n_forth.gif" width="380">
    </td>
    <td>
      <img src="figures/gif_output/rain_back_n_forth.gif" width="380">
    </td>
  </tr>
</table>

<img width="1096" alt="image" src='figures/qualitative_result.PNG'>

## Quantitative Comparison
### Seen dataset with synthetic degradation
<img width="720" alt="image" src='figures/seen_dataset_with_synthetic_degradation.PNG'>

### Unseen dataset with synthetic degradation
<img width="720" alt="image" src='figures/unseen_dataset_with_synthetic_degradation.PNG'>

### Unseen dataset with real degradation
<img width="600" alt="image" src='figures/unseen_dataset_with_real_degradation.PNG'>

## Reference
If you find this work useful, please consider citing us!
```bibtex
@inproceedings{chen2024robustsam,
  title={RobustSAM: Segment Anything Robustly on Degraded Images},
  author={Chen, Wei-Ting and Vong, Yu-Jiet and Kuo, Sy-Yen and Ma, Sizhou and Wang, Jian},
  booktitle={CVPR},
  year={2024}
}
```


## Acknowledgements
We thank the authors of [SAM](https://github.com/facebookresearch/segment-anything), on which our repo is based.