|
import os |
|
import requests |
|
|
|
|
|
# Disable the TorchScript JIT; this must be set before torch is imported.
os.environ["PYTORCH_JIT"] = "0"
|
|
|
from einops import rearrange |
|
import gradio as gr |
|
import numpy as np |
|
import spaces |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from PIL import Image, ImageOps |
|
from transformers import AutoModel, CLIPImageProcessor |
|
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry |
|
from segment_anything.modeling.image_encoder import ImageEncoderViT |
|
|
|
|
|
class RADIOVenc(nn.Module): |
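    """Wrap a RADIO model so it can replace SAM's ViT image encoder.

    The RADIO backbone's SAM-adaptor features are reshaped into a spatial
    grid and passed through the original SAM neck.
    """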
|
def __init__(self, radio: nn.Module, img_enc: ImageEncoderViT, img_size: int = 1024): |
|
super().__init__() |
|
self.radio = radio |
|
self.neck = img_enc.neck |
|
self.img_size = img_size |
|
        # Cache the conditioner's dtype so inputs can be cast to match it.
        self.dtype = radio.input_conditioner.dtype
|
|
|
def forward(self, x: torch.Tensor): |
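        """Encode `x` with RADIO and return SAM-compatible image embeddings."""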
|
h, w = x.shape[-2:] |
|
|
|
if self.dtype is not None: |
|
x = x.to(dtype=self.dtype) |
|
|
|
        # Use bfloat16 autocast only when the conditioner has no fixed dtype.
        with torch.autocast('cuda', dtype=torch.bfloat16, enabled=self.dtype is None):
|
output = self.radio(x) |
|
            # RADIO returns one output per adaptor; take the SAM adaptor's features.
            features = output["sam"].features
|
|
|
        # Reassemble the flattened token sequence into a 2D patch grid
        # (the model uses a 16x16 patch size).
        rows = h // 16
        cols = w // 16

        features = rearrange(features, 'b (h w) c -> b c h w', h=rows, w=cols)

        # Project through the original SAM neck to produce the embedding
        # shape the SAM decoder expects.
        features = self.neck(features)
|
|
|
return features |
|
|
|
|
|
def download_file(url, save_path): |
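    """Download `url` to `save_path`, skipping the download if the file already exists."""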
|
|
|
if os.path.exists(save_path): |
|
print(f"File already exists at {save_path}. Skipping download.") |
|
return |
|
|
|
print(f"Downloading from {url}") |
|
|
|
|
|
response = requests.get(url, stream=True) |
|
|
|
|
|
if response.status_code == 200: |
|
|
|
with open(save_path, 'wb') as file: |
|
|
|
for chunk in response.iter_content(chunk_size=1024): |
|
if chunk: |
|
file.write(chunk) |
|
print(f"File downloaded successfully and saved as {save_path}") |
|
else: |
|
print(f"Failed to download file. HTTP Status Code: {response.status_code}") |
|
|
|
|
|
hf_repo = "nvidia/RADIO-L"
# The HF repo provides the image preprocessing configuration.
image_processor = CLIPImageProcessor.from_pretrained(hf_repo)
|
|
|
model_version = "radio_v2.5-l" |
|
|
|
# Load RADIO via torch.hub, enabling the SAM adaptor head.
model = torch.hub.load(
|
'NVlabs/RADIO', |
|
'radio_model', |
|
version=model_version, |
|
progress=True, |
|
skip_validation=True, |
|
adaptor_names='sam') |
|
model.eval() |
|
|
|
# Build SAM ViT-H from the official checkpoint, then swap its image encoder
# for the RADIO-backed one.
local_sam_checkpoint_path = "sam_vit_h_4b8939.pth"
download_file("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", local_sam_checkpoint_path)
sam = sam_model_registry["vit_h"](checkpoint=local_sam_checkpoint_path)
model._patch_size = 16
sam.image_encoder = RADIOVenc(model, sam.image_encoder, img_size=1024)

# Move input normalization out of RADIO so SAM's preprocessing applies it.
conditioner = model.make_preprocessor_external()
sam.pixel_mean = conditioner.norm_mean * 255
sam.pixel_std = conditioner.norm_std * 255
|
|
|
|
|
def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False): |
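    """Fit a 3-component PCA projection with robust (median/MAD) color bounds.

    `m` is the MAD multiplier used to reject outliers when computing the
    per-channel min/max used for normalization.
    """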
|
|
|
|
|
assert len(features.shape) == 2, "features should be (N, C)" |
|
reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2] |
|
colors = features @ reduction_mat |
|
if remove_first_component: |
|
colors_min = colors.min(dim=0).values |
|
colors_max = colors.max(dim=0).values |
|
tmp_colors = (colors - colors_min) / (colors_max - colors_min) |
|
fg_mask = tmp_colors[..., 0] < 0.2 |
|
reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2] |
|
colors = features @ reduction_mat |
|
else: |
|
fg_mask = torch.ones_like(colors[:, 0]).bool() |
|
d = torch.abs(colors[fg_mask] - torch.median(colors[fg_mask], dim=0).values) |
|
mdev = torch.median(d, dim=0).values |
|
s = d / mdev |
|
try: |
|
rins = colors[fg_mask][s[:, 0] < m, 0] |
|
gins = colors[fg_mask][s[:, 1] < m, 1] |
|
bins = colors[fg_mask][s[:, 2] < m, 2] |
|
rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()]) |
|
rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()]) |
|
    except Exception:
        # Fall back to unfiltered bounds if outlier rejection empties a channel.
        rins = colors
        gins = colors
        bins = colors
        rgb_min = torch.tensor([rins.min(), gins.min(), bins.min()])
        rgb_max = torch.tensor([rins.max(), gins.max(), bins.max()])
|
|
|
return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat) |
|
|
|
|
|
def get_pca_map( |
|
feature_map: torch.Tensor, |
|
img_size, |
|
interpolation="bicubic", |
|
return_pca_stats=False, |
|
pca_stats=None, |
|
): |
|
""" |
|
feature_map: (1, h, w, C) is the feature map of a single image. |
|
""" |
|
    if feature_map.shape[0] != 1:
        # Assume an unbatched (h, w, C) map and add the missing batch dimension.
        feature_map = feature_map[None]
|
if pca_stats is None: |
|
reduct_mat, color_min, color_max = get_robust_pca( |
|
feature_map.reshape(-1, feature_map.shape[-1]) |
|
) |
|
else: |
|
reduct_mat, color_min, color_max = pca_stats |
|
pca_color = feature_map @ reduct_mat |
|
pca_color = (pca_color - color_min) / (color_max - color_min) |
|
pca_color = pca_color.clamp(0, 1) |
|
pca_color = F.interpolate( |
|
pca_color.permute(0, 3, 1, 2), |
|
size=img_size, |
|
mode=interpolation, |
|
).permute(0, 2, 3, 1) |
|
pca_color = pca_color.cpu().numpy().squeeze(0) |
|
if return_pca_stats: |
|
return pca_color, (reduct_mat, color_min, color_max) |
|
return pca_color |
|
|
|
|
|
def pad_image_to_multiple_of(image, multiple=16): |
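    """Symmetrically pad `image` with black so each side is a multiple of `multiple`."""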
|
|
|
width, height = image.size |
|
    new_width = (width + multiple - 1) // multiple * multiple
    new_height = (height + multiple - 1) // multiple * multiple
|
|
|
|
|
pad_width = new_width - width |
|
pad_height = new_height - height |
|
|
|
left = pad_width // 2 |
|
right = pad_width - left |
|
top = pad_height // 2 |
|
bottom = pad_height - top |
|
|
|
|
|
padded_image = ImageOps.expand(image, (left, top, right, bottom), fill='black') |
|
|
|
return padded_image |
|
|
|
|
|
def center_crop_resize(image, size=(1024, 1024)): |
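    """Center-crop `image` to a square and resize it to `size`."""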
|
|
|
width, height = image.size |
|
|
|
|
|
if width > height: |
|
new_width = height |
|
new_height = height |
|
left = (width - new_width) / 2 |
|
top = 0 |
|
right = (width + new_width) / 2 |
|
bottom = height |
|
else: |
|
new_width = width |
|
new_height = width |
|
left = 0 |
|
top = (height - new_height) / 2 |
|
right = width |
|
bottom = (height + new_height) / 2 |
|
|
|
|
|
image = image.crop((left, top, right, bottom)) |
|
|
|
|
|
image = image.resize(size, Image.LANCZOS) |
|
|
|
return image |
|
|
|
|
|
def visualize_anns(orig_image: np.ndarray, anns): |
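    """Overlay SAM masks on `orig_image` with random colors and opaque borders."""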
|
if len(anns) == 0: |
|
return orig_image |
|
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) |
|
|
|
    # 5x5 box filter used below to detect mask border pixels.
    kernel = torch.ones(1, 1, 5, 5, dtype=torch.float32)
|
|
|
|
|
mask = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4), dtype=np.float32) |
|
mask[:,:,3] = 0 |
|
for ann in sorted_anns: |
|
m = ann['segmentation'] |
|
color_mask = np.concatenate([np.random.random(3), [0.35]]) |
|
|
|
        # A pixel lies on the border if its 5x5 neighborhood is not entirely
        # inside the mask (a sum of 25 means all neighbors belong to it).
        tm = torch.as_tensor(m).reshape(1, 1, *m.shape).float()
        cvtm = F.conv2d(tm, kernel, padding=2)

        border_mask = (cvtm < 25).flatten(0, 2).numpy()

        mask[m] = color_mask
        # Make border pixels fully opaque (0.35 * (1 / 0.35) == 1.0).
        mask[m & border_mask, 3] *= 1.0 / 0.35
|
|
|
color, alpha = mask[..., :3], mask[..., -1:] |
|
|
|
orig_image = orig_image.astype(np.float32) / 255 |
|
overlay = alpha * color + (1 - alpha) * orig_image |
|
|
|
overlay = (overlay * 255).astype(np.uint8) |
|
return overlay |
|
|
|
|
|
|
|
@spaces.GPU |
|
def infer_radio(image): |
|
"""Define the function to generate the output.""" |
|
    model.cuda()
    conditioner.cuda()
    sam.cuda()
    # The automatic mask generator prompts SAM with a grid of points and
    # returns one binary mask per detected region.
    sam_generator = SamAutomaticMaskGenerator(sam, output_mode="binary_mask")
|
|
|
|
|
    # Pad so both sides are multiples of 256 before feature extraction.
    padded_image = pad_image_to_multiple_of(image, multiple=256)
|
width, height = padded_image.size |
|
pixel_values = image_processor(images=padded_image, return_tensors='pt').pixel_values |
|
pixel_values = pixel_values.to(torch.bfloat16).cuda() |
|
pixel_values = conditioner(pixel_values) |
|
|
|
    # The "backbone" output is a (summary, spatial_features) pair.
    _, features = model(pixel_values)["backbone"]
|
|
|
num_rows = height // model.patch_size |
|
num_cols = width // model.patch_size |
|
|
|
features = features.detach() |
|
features = rearrange(features, 'b (h w) c -> b h w c', h=num_rows, w=num_cols).float() |
|
|
|
pca_viz = get_pca_map(features, (height, width), interpolation='bilinear') |
|
|
|
|
|
    # SAM's automatic mask generator runs on the original (unpadded) image.
    image_array = np.array(image)
    print("image size", image_array.shape)
|
|
|
masks = sam_generator.generate(image_array) |
|
overlay = visualize_anns(image_array, masks) |
|
|
|
return pca_viz, overlay, f"{features.shape}" |
|
|
|
|
|
|
|
title = """RADIO: Reduce All Domains Into One""" |
|
|
|
description = """ |
|
# RADIO |
|
|
|
[AM-RADIO](https://github.com/NVlabs/RADIO) is a framework for distilling large vision foundation models into a single one.
|
RADIO, a new vision foundation model, excels across visual domains, serving as a superior replacement for vision backbones. |
|
Integrating CLIP variants, DINOv2, and SAM through distillation, it preserves unique features like text grounding and segmentation correspondence. |
|
Outperforming its teachers on ImageNet zero-shot (+6.8%), kNN (+2.39%), linear-probing segmentation (+3.8%), and vision-language models (LLaVa 1.5, up to 1.5%), it scales to any resolution and supports non-square images.
|
|
|
# Instructions |
|
|
|
Paste an image into the input box or pick one from the gallery of examples and then click the "Submit" button. |
|
The RADIO backbone features are reduced to 3 channels with a PCA projection and displayed as an RGB image.
|
The SAM features are processed using the SAM decoder and shown as an overlay on top of the input image. |
|
""" |
|
|
|
inputs = [ |
|
gr.Image(type="pil") |
|
] |
|
|
|
outputs = [ |
|
gr.Image(label="PCA Feature Visalization"), |
|
gr.Image(label="SAM Masks"), |
|
gr.Textbox(label="Feature Shape"), |
|
] |
|
|
|
|
|
demo = gr.Interface( |
|
fn=infer_radio, |
|
inputs=inputs, |
|
examples="./samples/", |
|
outputs=outputs, |
|
title=title, |
|
description=description, |
|
cache_examples=False |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch() |
|
|