|
import gradio as gr |
|
import numpy as np |
|
import requests |
|
import torch |
|
import yaml |
|
from PIL import Image |
|
from segmenter_model import utils |
|
from segmenter_model.factory import create_segmenter |
|
from segmenter_model.fpn_picie import PanopticFPN |
|
from segmenter_model.utils import colorize_one, map2cs |
|
from torchvision import transforms |
|
|
|
|
|
# Local path of the model checkpoint. Other parts of this module download to
# it, load from it, and inspect its name (e.g. for the 'nusc' substring).
WEIGHTS = "./weights/segmenter_nusc.pth"

# Blend factor used as the default ``alpha`` when overlaying a segmentation
# drawing on the input image.
ALPHA = 0.5

# Passed to the Gradio interface as ``cache_examples``.
CACHE = True

# When true, the '_full' variant YAML next to the weights is loaded.
FULL = True
|
|
|
|
|
def blend_images(bg, fg, alpha=ALPHA):
    """Blend ``fg`` over ``bg`` and return the resulting image.

    Both inputs are converted to RGBA before being combined with
    ``Image.blend``; ``alpha`` is forwarded unchanged as the blend factor.
    """
    overlay = fg.convert('RGBA')
    base = bg.convert('RGBA')
    return Image.blend(base, overlay, alpha=alpha)
|
|
|
|
|
def download_file_from_google_drive(destination=WEIGHTS):
    """Download the checkpoint from Google Drive and write it to ``destination``.

    The file is requested once; if the response carries a cookie whose name
    starts with ``download_warning``, the request is repeated with that
    cookie's value as the ``confirm`` parameter. The final response body is
    streamed to disk in fixed-size chunks.

    Args:
        destination: Path of the file to write (opened in binary write mode).

    Raises:
        requests.HTTPError: If the final response has an error status, so an
            error page is never silently stored as the weights file.
    """
    # Named ``file_id`` rather than ``id`` to avoid shadowing the builtin.
    file_id = '1v6_d2KHzRROsjb_cgxU7jvmnGVDXeBia'
    url = "https://docs.google.com/uc?export=download"
    chunk_size = 32768

    def get_confirm_token(response):
        """Return the value of the first 'download_warning*' cookie, or None."""
        return next(
            (value for key, value in response.cookies.items() if key.startswith('download_warning')),
            None,
        )

    def save_response_content(response, destination):
        """Stream the response body to ``destination`` chunk by chunk."""
        with open(destination, "wb") as f:
            for chunk in response.iter_content(chunk_size):
                if chunk:  # skip empty chunks
                    f.write(chunk)

    # The session and the streamed responses hold connections open, so they
    # are released explicitly even when the download fails part-way.
    with requests.Session() as session:
        response = session.get(url, params={'id': file_id}, stream=True)
        try:
            token = get_confirm_token(response)
            if token:
                response.close()
                response = session.get(url, params={'id': file_id, 'confirm': token}, stream=True)
            response.raise_for_status()
            save_response_content(response, destination)
        finally:
            response.close()
|
|
|
|
|
def download_weights():
    """Download the checkpoint from the project server to ``WEIGHTS``.

    The parent directory of ``WEIGHTS`` is created first if it is missing,
    because ``urlretrieve`` cannot write into a non-existent directory.
    The file is fetched on every call, overwriting any existing copy.
    """
    import os
    import urllib.request

    print('Downloading weights...')

    url = 'https://data.ciirc.cvut.cz/public/projects/2022DriveAndSegment/segmenter_nusc.pth'

    target_dir = os.path.dirname(WEIGHTS)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    urllib.request.urlretrieve(url, WEIGHTS)
|
|
|
|
|
def segment_segmenter(image, model, window_size, window_stride, encoder_features=False, decoder_features=False,
                      no_upsample=False, batch_size=1):
    """Run ``utils.inference`` on ``image`` and post-process the result.

    All arguments are forwarded to ``utils.inference``; the last two entries
    of ``image.shape`` are passed as the original size.

    Returns:
        The raw inference output when ``encoder_features`` or
        ``decoder_features`` is requested. Otherwise the argmax over
        dimension 1 (presumably the class dimension -- confirm against
        ``utils.inference``), with that dimension re-inserted as size 1.
    """
    output = utils.inference(
        model,
        image,
        image.shape[-2:],
        window_size,
        window_stride,
        batch_size=batch_size,
        no_upsample=no_upsample,
        encoder_features=encoder_features,
        decoder_features=decoder_features,
    )
    # Feature requests are returned untouched; only predictions are reduced.
    if encoder_features or decoder_features:
        return output
    return output.argmax(1).unsqueeze(1)
|
|
|
|
|
def remap(seg_pred, ignore=255):
    """Translate predicted ids into target ids using a fixed lookup table.

    The table is chosen by whether the weights path contains 'nusc'
    (case-insensitive). Every position whose id is not a key of the table
    keeps the ``ignore`` value.

    Args:
        seg_pred: Array-like of predicted ids; its last two dimensions give
            the output height and width.
        ignore: Fill value for ids without a mapping.

    Returns:
        A ``(h, w)`` NumPy array created as ``uint8`` ones scaled by ``ignore``.
    """
    nusc_table = {0: 0, 13: 1, 2: 2, 7: 3, 17: 4, 20: 5, 8: 6, 12: 7, 26: 8, 14: 9, 22: 10, 11: 11, 6: 12, 27: 13,
                  10: 14, 19: 15, 24: 16, 9: 17, 4: 18}
    other_table = {0: 0, 12: 1, 15: 2, 23: 3, 10: 4, 14: 5, 18: 6, 2: 7, 17: 8, 13: 9, 8: 10, 3: 11, 27: 12, 4: 13,
                   25: 14, 24: 15, 6: 16, 22: 17, 28: 18}
    table = nusc_table if 'nusc' in WEIGHTS.lower() else other_table

    height, width = seg_pred.shape[-2:]
    # Start from an all-``ignore`` canvas and paint each mapped id onto it.
    remapped = np.ones((height, width), dtype=np.uint8) * ignore
    for source_id, target_id in table.items():
        remapped[seg_pred == source_id] = target_id
    return remapped
|
|
|
|
|
def create_model(resnet=False):
    """Build the segmentation model and load its weights on the CPU.

    The variant YAML file located next to the weights ('<weights>_variant.yml'
    or '<weights>_variant_full.yml', depending on ``FULL``) supplies the
    inference settings and the network keyword arguments.

    Args:
        resnet: If true, build a ``PanopticFPN``; otherwise build a Segmenter
            via ``create_segmenter`` with decoder dropout forced to 0.

    Returns:
        A tuple ``(model, window_size, window_stride, im_size)`` with the
        model switched to eval mode.
    """
    weights_path = WEIGHTS
    variant_path = '{}_variant{}.yml'.format(weights_path, '_full' if FULL else '')

    print('Use weights {}'.format(weights_path))
    print('Load variant from {}'.format(variant_path))
    # Context manager so the config file handle is closed deterministically.
    with open(variant_path, "r") as variant_file:
        variant = yaml.load(variant_file, Loader=yaml.FullLoader)

    inference_kwargs = variant['inference_kwargs']
    window_size = inference_kwargs["window_size"]
    window_stride = inference_kwargs["window_stride"]
    im_size = inference_kwargs["im_size"]

    net_kwargs = variant["net_kwargs"]
    if resnet:
        model = PanopticFPN(arch=net_kwargs['backbone'], pretrain=net_kwargs['pretrain'], n_cls=net_kwargs['n_cls'])
    else:
        # Dropout is disabled for the Segmenter decoder before construction.
        net_kwargs['decoder']['dropout'] = 0.
        model = create_segmenter(net_kwargs)

    print('Load weights from {}'.format(weights_path))
    # NOTE(review): torch.load unpickles the checkpoint, and this file is
    # downloaded at start-up -- only safe while the download source is trusted.
    weights = torch.load(weights_path, map_location=torch.device('cpu'))['model']
    model.load_state_dict(weights, strict=True)

    model.eval()

    return model, window_size, window_stride, im_size
|
|
|
|
|
# Import-time setup: fetch the checkpoint, then build the model once so that
# every later call can reuse these module-level globals.
download_weights()
(model, window_size, window_stride, im_size) = create_model()
|
|
|
|
|
def get_transformations(input_img):
    """Return the preprocessing pipeline applied to an image before inference.

    The pipeline converts the image to a tensor, resizes it to the
    module-level ``im_size`` and normalizes it with fixed per-channel mean and
    std (presumably the ImageNet statistics -- confirm against training).

    Args:
        input_img: Currently unused; kept so existing callers keep working.

    Returns:
        A ``transforms.Compose`` of the three steps above.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(im_size),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
|
|
|
|
|
def _overlay_on_input(drawing, input_img_pil):
    """Convert ``drawing`` to a PIL image at the input's size and blend it over the input."""
    drawing_pil = transforms.ToPILImage()(drawing).resize(input_img_pil.size)
    return blend_images(input_img_pil, drawing_pil)


def predict(input_img):
    """Segment the image at ``input_img`` and return two blended visualizations.

    Args:
        input_img: Path of the image file to segment.

    Returns:
        A tuple ``(drawing_blend_pseudo, drawing_blend_cs)``: the input image
        blended with the ``colorize_one`` drawing and with the ``map2cs``
        drawing of the remapped segmentation, respectively.
    """
    # Force three channels: the normalization step uses 3-channel statistics,
    # so RGBA, grayscale or palette images would otherwise fail. The context
    # manager closes the file; ``convert`` returns an independent image.
    with Image.open(input_img) as opened:
        input_img_pil = opened.convert('RGB')

    transform = get_transformations(input_img_pil)
    batch = torch.unsqueeze(transform(input_img_pil), 0)

    with torch.no_grad():
        segmentation = segment_segmenter(batch, model, window_size, window_stride).squeeze().detach()
        segmentation_remap = remap(segmentation)

    drawing_pseudo = colorize_one(segmentation_remap)
    drawing_cs = map2cs(segmentation_remap)

    drawing_blend_cs = _overlay_on_input(drawing_cs, input_img_pil)
    drawing_blend_pseudo = _overlay_on_input(drawing_pseudo, input_img_pil)

    return drawing_blend_pseudo, drawing_blend_cs
|
|
|
|
|
# Texts shown by the Gradio interface: page title, short description and a
# longer markdown/HTML article.
title = 'Drive&Segment'

description = (
    'Gradio Demo accompanying paper "Drive&Segment: Unsupervised Semantic Segmentation of Urban Scenes '
    'via Cross-modal Distillation"\n'
    'Because of the CPU-only inference, it might take up to 20s for large images.\n'
    'Right now, it uses the Segmenter model trained on nuScenes and with a simplified inference scheme '
    '(for the sake of speed). Please see description below the app for more details.'
)

article = """
<h1 align="center">🚙📷 Drive&Segment: Unsupervised Semantic Segmentation of Urban Scenes via Cross-modal Distillation</h1>

## 💫 Highlights

- 🚫🔬 **Unsupervised semantic segmentation:** Drive&Segments proposes learning semantic segmentation in urban scenes without any manual annotation, just from
the raw non-curated data collected by cars which, equipped with 📷 cameras and 💥 LiDAR sensors.
- 📷💥 **Multi-modal training:** During the train time our method takes 📷 images and 💥 LiDAR scans as an input, and
learns a semantic segmentation model *without using manual annotations*.
- 📷 **Image-only inference:** During the inference time, Drive&Segments takes *only images* as an input.
- 🏆 **State-of-the-art performance:** Our best single model based on Segmenter architecture achieves **21.8%** in mIoU on
Cityscapes (without any fine-tuning).


"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Example inputs offered in the UI: three fixed images followed by the
# numbered 'cs' images 3 and 4.
examples = [
    'examples/100.jpeg',
    'examples/img1.jpg',
    'examples/snow1.jpg',
    *('examples/cs{}.jpg'.format(idx) for idx in range(3, 5)),
]
|
|
|
# Wire ``predict`` into a Gradio interface: one image input handed over as a
# file path, two PIL image outputs (pseudo labels and the Cityscapes mapping).
iface = gr.Interface(
    predict,
    inputs=gr.Image(type='filepath'),
    outputs=[
        gr.Image(label="Pseudo segmentation", type="pil"),
        gr.Image(label="Mapping to Cityscapes", type="pil"),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=CACHE,
)

# NOTE(review): ``enable_queue`` as a ``launch()`` argument appears to be
# deprecated or removed in newer Gradio releases -- confirm against the
# Gradio version this app is pinned to.
iface.launch(enable_queue=True, inline=True)
|
|