DaS / app.py
vobecant
Initial commit.
ff9b5a2
import gradio as gr
import numpy as np
import requests
import torch
import yaml
from PIL import Image
from segmenter_model import utils
from segmenter_model.factory import create_segmenter
from segmenter_model.fpn_picie import PanopticFPN
from segmenter_model.utils import colorize_one, map2cs
from torchvision import transforms
# Checkpoint served by the demo (an alternative is './weights/segmenter.pth').
# remap() picks its label table from this path and create_model() derives the
# variant YAML path from it.
WEIGHTS: str = './weights/segmenter_nusc.pth'
# Selects the "_full" variant YAML next to the checkpoint in create_model().
FULL: bool = True
# Passed to gr.Interface as cache_examples.
CACHE: bool = True
# Default opacity of the segmentation overlay in blend_images().
ALPHA: float = 0.5
def blend_images(bg, fg, alpha=ALPHA):
    """Overlay *fg* on top of *bg* with opacity *alpha*.

    Both inputs are converted to RGBA first so that ``Image.blend`` receives
    matching modes (it also requires both images to have the same size).

    Args:
        bg: PIL image used as the background.
        fg: PIL image blended over the background.
        alpha: interpolation factor for ``Image.blend``; 0 yields *bg*,
            1 yields *fg*.

    Returns:
        The blended RGBA PIL image.
    """
    foreground = fg.convert('RGBA')
    background = bg.convert('RGBA')
    return Image.blend(background, foreground, alpha=alpha)
def download_file_from_google_drive(destination=WEIGHTS, file_id='1v6_d2KHzRROsjb_cgxU7jvmnGVDXeBia'):
    """Download a file shared on Google Drive to *destination*.

    If the first response carries a ``download_warning*`` cookie, its value is
    sent back as the ``confirm`` parameter in a second request, and that second
    response is the one saved.

    NOTE(review): nothing in the visible code calls this; ``download_weights``
    is the active download path.

    Args:
        destination: path the downloaded bytes are written to.
        file_id: Google Drive file id; defaults to the demo checkpoint.

    Raises:
        requests.HTTPError: if the final response has an error status, rather
            than writing the error page into *destination*.
    """
    url = "https://docs.google.com/uc?export=download"
    chunk_size = 32768

    def get_confirm_token(response):
        # The confirmation token is delivered as a cookie named download_warning*.
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                return value
        return None

    def save_response_content(response, path):
        with open(path, "wb") as f:
            for chunk in response.iter_content(chunk_size):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)

    with requests.Session() as session:
        response = session.get(url, params={'id': file_id}, stream=True)
        token = get_confirm_token(response)
        if token:
            # Release the connection held by the warning page before re-requesting.
            response.close()
            response = session.get(url, params={'id': file_id, 'confirm': token}, stream=True)
        with response:
            response.raise_for_status()
            save_response_content(response, destination)
def download_weights(destination=None, force=False,
                     url='https://data.ciirc.cvut.cz/public/projects/2022DriveAndSegment/segmenter_nusc.pth'):
    """Fetch the Segmenter checkpoint unless it is already on disk.

    The file is first downloaded to ``<destination>.part`` and then moved into
    place, so an interrupted download never leaves a truncated checkpoint at
    *destination*.

    Args:
        destination: target path; defaults to the module-level ``WEIGHTS``
            (read at call time).
        force: download even if *destination* already exists.
        url: location of the checkpoint.
    """
    import os
    import urllib.request

    destination = os.fspath(WEIGHTS if destination is None else destination)
    if os.path.exists(destination) and not force:
        print('Weights already present at {}, skipping download.'.format(destination))
        return

    print('Downloading weights...')
    target_dir = os.path.dirname(destination)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    tmp_path = destination + '.part'
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic on the same filesystem: readers see the old file or the full new one.
        os.replace(tmp_path, destination)
    finally:
        # On failure, drop the partial file instead of leaving it around.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def segment_segmenter(image, model, window_size, window_stride, encoder_features=False, decoder_features=False,
                      no_upsample=False, batch_size=1):
    """Run windowed inference of *model* on *image* via ``utils.inference``.

    Args:
        image: input tensor whose last two dims are the spatial size passed
            to ``utils.inference``; ``predict`` supplies a (1, C, H, W)
            normalized tensor.
        model: network to evaluate.
        window_size: window size forwarded to ``utils.inference``.
        window_stride: window stride forwarded to ``utils.inference``.
        encoder_features: request encoder features instead of predictions.
        decoder_features: request decoder features instead of predictions.
        no_upsample: forwarded to ``utils.inference``.
        batch_size: forwarded to ``utils.inference``.

    Returns:
        The raw ``utils.inference`` output when either feature flag is set;
        otherwise its argmax over dim 1, with dim 1 re-inserted as size 1.
    """
    output = utils.inference(
        model, image, image.shape[-2:], window_size, window_stride,
        batch_size=batch_size,
        no_upsample=no_upsample,
        encoder_features=encoder_features,
        decoder_features=decoder_features,
    )
    if encoder_features or decoder_features:
        return output
    # Collapse class scores to a label map, keeping a singleton channel dim.
    return output.argmax(1).unsqueeze(1)
def remap(seg_pred, ignore=255):
    """Translate pseudo-label ids into ground-truth class ids (0-18).

    The lookup table depends on the checkpoint: the nuScenes table is used
    when 'nusc' occurs in ``WEIGHTS`` (case-insensitive), the other table
    otherwise. Pixels whose pseudo label has no entry receive *ignore*.
    The result feeds ``colorize_one`` / ``map2cs`` in ``predict``.

    Args:
        seg_pred: label map whose last two dims are (H, W); ``predict``
            passes a 2-D torch tensor.
        ignore: fill value for unmapped pixels.

    Returns:
        numpy array of shape (H, W), built from a uint8 base array.
    """
    nusc_table = {0: 0, 13: 1, 2: 2, 7: 3, 17: 4, 20: 5, 8: 6, 12: 7, 26: 8, 14: 9, 22: 10, 11: 11, 6: 12, 27: 13,
                  10: 14, 19: 15, 24: 16, 9: 17, 4: 18}
    other_table = {0: 0, 12: 1, 15: 2, 23: 3, 10: 4, 14: 5, 18: 6, 2: 7, 17: 8, 13: 9, 8: 10, 3: 11, 27: 12, 4: 13,
                   25: 14, 24: 15, 6: 16, 22: 17, 28: 18}
    table = nusc_table if 'nusc' in WEIGHTS.lower() else other_table

    height, width = seg_pred.shape[-2:]
    remapped = np.ones((height, width), dtype=np.uint8) * ignore
    for pseudo_id, gt_id in table.items():
        remapped[seg_pred == pseudo_id] = gt_id
    return remapped
def create_model(resnet=False):
    """Build the segmentation network and load the checkpoint at ``WEIGHTS``.

    Hyper-parameters come from ``<WEIGHTS>_variant[_full].yml``; the
    ``_full`` suffix is used when ``FULL`` is true.

    Args:
        resnet: build a ``PanopticFPN`` instead of a Segmenter.

    Returns:
        Tuple ``(model, window_size, window_stride, im_size)``. The last three
        values come from the variant's ``inference_kwargs``. The model is in
        eval mode, and its weights were loaded with ``map_location='cpu'``.
    """
    weights_path = WEIGHTS
    variant_path = '{}_variant{}.yml'.format(weights_path, '_full' if FULL else '')
    print('Use weights {}'.format(weights_path))
    print('Load variant from {}'.format(variant_path))
    with open(variant_path, "r") as f:
        variant = yaml.load(f, Loader=yaml.FullLoader)

    inference_kwargs = variant['inference_kwargs']
    window_size = inference_kwargs["window_size"]
    window_stride = inference_kwargs["window_stride"]
    im_size = inference_kwargs["im_size"]
    net_kwargs = variant["net_kwargs"]

    if resnet:
        model = PanopticFPN(arch=net_kwargs['backbone'], pretrain=net_kwargs['pretrain'], n_cls=net_kwargs['n_cls'])
    else:
        # Inference only: turn off decoder dropout in the config before building.
        net_kwargs['decoder']['dropout'] = 0.
        model = create_segmenter(net_kwargs)

    print('Load weights from {}'.format(weights_path))
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'], strict=True)
    model.eval()
    return model, window_size, window_stride, im_size
# Fetch the checkpoint and build the model once, at import time. predict() and
# get_transformations() read these module-level names on every request.
download_weights()
model, window_size, window_stride, im_size = create_model()
def get_transformations(input_img):
    """Build the preprocessing pipeline applied to an input image.

    The pipeline is ToTensor -> Resize(im_size) -> Normalize with ImageNet
    mean/std. ``im_size`` is the module-level value produced by
    ``create_model``.

    Args:
        input_img: the PIL image about to be transformed. It is currently
            unused, because the pipeline is identical for every image; it is
            kept for caller compatibility.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        # For an int size, torchvision resizes so the shorter side equals it.
        transforms.Resize(im_size),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
def predict(input_img):
    """Segment one image and return the two overlays shown by the demo.

    Args:
        input_img: path to the image file (the Gradio input uses
            ``type='filepath'``).

    Returns:
        Tuple ``(pseudo_overlay, cityscapes_overlay)`` of RGBA PIL images.
        Each has the input's size and shows a colourised label map blended
        over the input image.
    """
    # Load eagerly so the file handle is closed. Force 3 channels so RGBA,
    # palette or grayscale uploads match the 3-channel Normalize transform.
    with Image.open(input_img) as img:
        input_img_pil = img.convert('RGB')

    transform = get_transformations(input_img_pil)
    batch = torch.unsqueeze(transform(input_img_pil), 0)
    with torch.no_grad():
        segmentation = segment_segmenter(batch, model, window_size, window_stride).squeeze()
    segmentation_remap = remap(segmentation)

    drawing_pseudo = colorize_one(segmentation_remap)
    drawing_cs = map2cs(segmentation_remap)

    # NEAREST keeps the exact class colours when upsampling the label maps
    # back to the input resolution; interpolating filters mix colours at edges.
    drawing_cs = transforms.ToPILImage()(drawing_cs).resize(input_img_pil.size, Image.NEAREST)
    drawing_blend_cs = blend_images(input_img_pil, drawing_cs)
    drawing_pseudo = transforms.ToPILImage()(drawing_pseudo).resize(input_img_pil.size, Image.NEAREST)
    drawing_blend_pseudo = blend_images(input_img_pil, drawing_pseudo)
    return drawing_blend_pseudo, drawing_blend_cs
# Header text rendered by gr.Interface above the app.
title = 'Drive&Segment'
description = (
    'Gradio demo accompanying the paper "Drive&Segment: Unsupervised Semantic Segmentation of Urban Scenes '
    'via Cross-modal Distillation"\n'
    'Because inference runs on the CPU only, it might take up to 20s for large images.\n'
    'Right now, it uses the Segmenter model trained on nuScenes with a simplified inference scheme '
    '(for the sake of speed). Please see the description below the app for more details.'
)
# article = "<p style='text-align: center'><a href='https://vobecant.github.io/DriveAndSegment/' target='_blank'>Project Page</a> | <a href='https://github.com/vobecant/DriveAndSegment' target='_blank'>Github</a></p>"
article = """
<h1 align="center">🚙📷 Drive&Segment: Unsupervised Semantic Segmentation of Urban Scenes via Cross-modal Distillation</h1>
## 💫 Highlights
- 🚫🔬 **Unsupervised semantic segmentation:** Drive&Segments proposes learning semantic segmentation in urban scenes without any manual annotation, just from
the raw non-curated data collected by cars which, equipped with 📷 cameras and 💥 LiDAR sensors.
- 📷💥 **Multi-modal training:** During the train time our method takes 📷 images and 💥 LiDAR scans as an input, and
learns a semantic segmentation model *without using manual annotations*.
- 📷 **Image-only inference:** During the inference time, Drive&Segments takes *only images* as an input.
- 🏆 **State-of-the-art performance:** Our best single model based on Segmenter architecture achieves **21.8%** in mIoU on
Cityscapes (without any fine-tuning).
"""
# ![teaser](https://drive.google.com/uc?export=view&id=1MkQmAfBPUomJDUikLhM_Wk8VUNekPb91)
# <h2 align="center">
# <a href="https://vobecant.github.io/DriveAndSegment">project page</a> |
# <a href="http://arxiv.org/abs/2203.11160">arXiv</a> |
# <a href="https://huggingface.co/spaces/vobecant/DaS">Gradio</a> |
# <a href="https://colab.research.google.com/drive/126tBVYbt1s0STyv8DKhmLoHKpvWcv33H?usp=sharing">Colab</a> |
# <a href="https://www.youtube.com/watch?v=B9LK-Fxu7ao">video</a>
# </h2>
# description += """
# ## 📺 Examples
#
# ### **Pseudo** segmentation.
#
# Example of **pseudo** segmentation.
#
# ![](https://drive.google.com/uc?export=view&id=1n27_zAMBAc2e8hEzh5FTDNM-V6zKAE4p)
# ### Cityscapes segmentation.
#
# Two examples of pseudo segmentation mapped to the 19 ground-truth classes of the Cityscapes dataset by using the
# Hungarian algorithm.
#
# ![](https://drive.google.com/uc?export=view&id=1vHF2DugjXr4FdXX3gW65GRPArNL5urEH)
# ![](https://drive.google.com/uc?export=view&id=1WI_5lmF_YoVFXdWDnPT29rhPnlylh7QV)
# """
# Sample inputs listed under the demo (pre-computed when CACHE is enabled).
examples = [
    'examples/100.jpeg',
    'examples/img1.jpg',
    'examples/snow1.jpg',
    *('examples/cs{}.jpg'.format(idx) for idx in (3, 4)),
]
# Gradio UI: one image path in, two overlay images out.
iface = gr.Interface(
    predict,
    inputs=gr.Image(type='filepath'),
    outputs=[gr.Image(label="Pseudo segmentation", type="pil"),
             gr.Image(label="Mapping to Cityscapes", type="pil")],
    title=title,
    description=description,
    article=article,
    examples=examples,
    cache_examples=CACHE,
)
# CPU inference can take ~20 s per image, so requests go through Gradio's queue.
# queue() replaces the deprecated launch(enable_queue=...) argument, which newer
# Gradio releases no longer accept.
iface.queue().launch(inline=True)