GroupViT / app.py
xvjiarui's picture
replace demo
a0899f8
# Modified from the implementation of https://huggingface.co/akhaliq
import os
import sys
# Fetch the upstream GroupViT sources on first start-up and put the checkout
# at the front of the import path. The project-local imports that follow
# (`datasets`, `models`, `segmentation`, `utils`) are presumably meant to
# resolve to modules inside this checkout.
import subprocess

# Only clone when the checkout is missing or empty: `git clone` refuses to
# write into a non-empty directory, so re-running it on every restart would
# just emit an error.
if not (os.path.isdir('GroupViT') and os.listdir('GroupViT')):
    # Argument list instead of a shell string, and check=True so that a failed
    # clone stops the app here rather than surfacing later as an import error.
    subprocess.run(
        ['git', 'clone', 'https://github.com/NVlabs/GroupViT', 'GroupViT'],
        check=True)
sys.path.insert(0, 'GroupViT')
import os.path as osp
from collections import namedtuple
import gradio as gr
import mmcv
import numpy as np
import torch
from datasets import build_text_transform
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.image import tensor2imgs
from mmcv.parallel import collate, scatter
from models import build_model
from omegaconf import read_write
from segmentation.datasets import (COCOObjectDataset, PascalContextDataset,
PascalVOCDataset)
from segmentation.evaluation import (GROUP_PALETTE, build_seg_demo_pipeline,
build_seg_inference)
from utils import get_config, load_checkpoint
import shutil
# Copy the bundled example images into the GroupViT checkout once, then make
# the checkout the working directory so that relative paths used afterwards
# are resolved from inside it.
hg_demo_present = osp.exists('GroupViT/hg_demo')
if not hg_demo_present:
    shutil.copytree(src='demo/', dst='GroupViT/hg_demo/')

os.chdir('GroupViT')
# --- Model / runtime settings ------------------------------------------------
# Checkpoint handed to the config loader as the `resume` argument.
checkpoint_url = 'https://github.com/xvjiarui/GroupViT/releases/download/v1.0.0/group_vit_gcc_yfcc_30e-879422e0.pth'
cfg_path = 'configs/group_vit_gcc_yfcc_30e.yml'
output_dir = 'demo/output'
device = 'cpu'

# --- Visualisations ----------------------------------------------------------
# One output image is produced per mode; `output_labels` holds the UI caption
# for the mode at the same position.
# NOTE(review): an earlier revision of this list also contained 'first_group'.
vis_modes = ['input_pred_label', 'final_group']
output_labels = ['segmentation map', 'groups']

# --- UI choices and clickable examples ---------------------------------------
dataset_options = ['Pascal VOC', 'Pascal Context', 'COCO']
_example_images = ('hg_demo/voc.jpg', 'hg_demo/ctx.jpg', 'hg_demo/coco.jpg')
# Each example row is [category list, extra classes, image path].
examples = [[option, '', image]
            for option, image in zip(dataset_options, _example_images)]

# --- Stand-in for the command-line namespace consumed by get_config ----------
PSEUDO_ARGS = namedtuple('PSEUDO_ARGS', 'cfg opts resume vis local_rank')
args = PSEUDO_ARGS(
    cfg=cfg_path,
    opts=[],
    resume=checkpoint_url,
    vis=vis_modes,
    local_rank=0,
)
# Load the configuration and flag it as evaluation-only; the config is only
# writable inside a `read_write` context.
cfg = get_config(args)
with read_write(cfg):
    cfg.evaluate.eval_only = True

# Build the network, swap out SyncBatchNorm layers, move it to the target
# device and put it in eval mode before loading the checkpoint weights.
model = revert_sync_batchnorm(build_model(cfg.model))
model.to(device)
model.eval()
# The two trailing positional arguments are passed as None, as in the original.
load_checkpoint(cfg, model, None, None)

# Pre-processing shared by every request: text transform and image pipeline.
test_pipeline = build_seg_demo_pipeline()
text_transform = build_text_transform(False, cfg.data.text_aug, with_dc=False)
def _parse_additional_classes(additional_classes, known_classes):
    """Return the new, unique class names from a comma-separated string.

    Names are stripped of surrounding whitespace; empty entries, duplicates
    and names already in ``known_classes`` are dropped. The order in which the
    user typed the names is preserved.
    """
    if not additional_classes:
        return []
    seen = set(known_classes)
    extra = []
    for raw_name in additional_classes.split(','):
        name = raw_name.strip()
        if name and name not in seen:
            seen.add(name)
            extra.append(name)
    return extra


def inference(dataset, additional_classes, input_img):
    """Segment ``input_img`` with GroupViT and return the visualisation files.

    Args:
        dataset: Category list to segment with. Accepts the UI labels
            ('Pascal VOC', 'Pascal Context', 'COCO') or the short names
            ('voc', 'context', 'coco').
        additional_classes: Comma-separated extra class names; an empty
            string or None means no extra classes.
        input_img: Image handed to the demo pipeline under the ``img`` key
            (the Gradio interface in this file passes a file path).

    Returns:
        A list with one output image path per entry of ``vis_modes``, in the
        same order.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset in ('voc', 'Pascal VOC'):
        dataset_class = PascalVOCDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_voc12.py'
    elif dataset in ('coco', 'COCO'):
        dataset_class = COCOObjectDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/coco.py'
    elif dataset in ('context', 'Pascal Context'):
        dataset_class = PascalContextDataset
        seg_cfg = 'segmentation/configs/_base_/datasets/pascal_context.py'
    else:
        # Report the offending value itself; the module-level `args` tuple has
        # no `dataset` field.
        raise ValueError('Unknown dataset: {}'.format(dataset))

    # NOTE: this mutates the module-level config on every request.
    with read_write(cfg):
        cfg.evaluate.seg.cfg = seg_cfg
        cfg.evaluate.seg.opts = ['test_cfg.mode=whole']

    # Work on copies so the dataset classes' own CLASSES / PALETTE stay intact.
    dataset_cfg = mmcv.Config()
    dataset_cfg.CLASSES = list(dataset_class.CLASSES)
    dataset_cfg.PALETTE = dataset_class.PALETTE.copy()

    extra_classes = _parse_additional_classes(additional_classes,
                                              dataset_cfg.CLASSES)
    if extra_classes:
        dataset_cfg.CLASSES.extend(extra_classes)
        # Give each extra class a randomly drawn colour from GROUP_PALETTE
        # (drawn with replacement, so colours may repeat).
        palette_idx = np.random.choice(len(GROUP_PALETTE), len(extra_classes))
        dataset_cfg.PALETTE.extend(GROUP_PALETTE[palette_idx])

    seg_model = build_seg_inference(model, dataset_cfg, text_transform,
                                    cfg.evaluate.seg)
    first_param = next(seg_model.parameters())
    seg_device = first_param.device

    # Prepare a single-sample batch.
    data = test_pipeline(dict(img=input_img))
    data = collate([data], samples_per_gpu=1)
    if first_param.is_cuda:
        # Scatter to the GPU the model lives on.
        data = scatter(data, [seg_device])[0]
    else:
        # On CPU, unwrap the collated meta containers by hand.
        data['img_metas'] = [i.data[0] for i in data['img_metas']]

    with torch.no_grad():
        result = seg_model(return_loss=False, rescale=False, **data)

    img_tensor = data['img'][0]
    img_metas = data['img_metas'][0]
    imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
    assert len(imgs) == len(img_metas)

    # NOTE: output paths are fixed per visualisation mode, so every request
    # overwrites the files written by the previous one.
    out_file_dict = {}
    for img, img_meta in zip(imgs, img_metas):
        h, w, _ = img_meta['img_shape']
        # Crop to the shape recorded in the image meta.
        img_show = img[:h, :w, :]
        for vis_mode in vis_modes:
            out_file = osp.join(output_dir, 'vis_imgs', vis_mode,
                                f'{vis_mode}.jpg')
            seg_model.show_result(img_show, img_tensor.to(seg_device), result,
                                  out_file, vis_mode)
            out_file_dict[vis_mode] = out_file

    return [out_file_dict[mode] for mode in vis_modes]
# Static text shown on the demo page: heading, introduction above the
# interface, and an HTML footer linking to the paper and the source repository.
title = 'GroupViT'
description = """
Gradio Demo for GroupViT: Semantic Segmentation Emerges from Text Supervision. \n
You may click one of the examples or upload your own image. \n
GroupViT could perform open vocabulary segmentation, you may input more classes (separated by commas).
"""
article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2202.11094' target='_blank'>
GroupViT: Semantic Segmentation Emerges from Text Supervision
</a>
|
<a href='https://github.com/NVlabs/GroupViT' target='_blank'>Github Repo</a></p>
"""
# Assemble the Gradio UI and start serving it.
# NOTE(review): `gr.inputs` / `gr.outputs` and `enable_queue` look like the
# legacy Gradio API - confirm against the Gradio version this Space pins.

# Input widgets, in the order of `inference`'s parameters:
# dataset, additional_classes, input_img.
category_dropdown = gr.inputs.Dropdown(
    dataset_options, type='value', label='Category list')
extra_classes_box = gr.inputs.Textbox(
    lines=1, placeholder=None, default='', label='More classes')
image_input = gr.inputs.Image(type='filepath')

# One output image per label in `output_labels`.
result_images = [gr.outputs.Image(label=label) for label in output_labels]

demo = gr.Interface(
    inference,
    inputs=[category_dropdown, extra_classes_box, image_input],
    outputs=result_images,
    title=title,
    description=description,
    article=article,
    examples=examples)
demo.launch(enable_queue=True)