# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
"""Gradio demo: segment an uploaded image into user-supplied classes.

The app builds a detectron2 config (CPU only), runs ``VisualizationDemo`` on
the uploaded image with the comma-separated class names typed by the user,
writes one binary mask PNG per class under ``masks/`` and returns the rendered
visualization to the browser.
"""
import argparse
import functools
import glob
import multiprocessing as mp
import os

# Hide all GPUs. This must be set before any CUDA-aware library is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# HACK: detectron2 is installed at import time (it is imported further down),
# so this call has to stay above the detectron2 imports.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')

# fmt: off
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))
# fmt: on

import tempfile
import time
import warnings

import cv2
import numpy as np
import tqdm

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.projects.deeplab import add_deeplab_config
from detectron2.utils.logger import setup_logger

from cat_seg import add_cat_seg_config
from demo.predictor import VisualizationDemo
import gradio as gr
from matplotlib.backends.backend_agg import FigureCanvasAgg as fc

# constants
WINDOW_NAME = "MaskFormer demo"


def setup_cfg(args):
    """Build the frozen config from the config file and command-line overrides.

    Args:
        args: Parsed arguments providing ``config_file`` and ``opts``.

    Returns:
        The frozen config node with ``MODEL.DEVICE`` forced to ``"cpu"``.
    """
    # load config from file and command-line arguments
    cfg = get_cfg()
    add_deeplab_config(cfg)
    add_cat_seg_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Applied after the merges so neither the file nor --opts can re-enable GPU.
    cfg.MODEL.DEVICE = "cpu"
    cfg.freeze()
    return cfg


def get_parser():
    """Return the argument parser for the demo's command-line options."""
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
    parser.add_argument(
        "--config-file",
        default="configs/vitl_swinb_384.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[
            "MODEL.WEIGHTS", "model_final.pth",
            "MODEL.SEM_SEG_HEAD.TRAIN_CLASS_JSON", "datasets/voc20.json",
            "MODEL.SEM_SEG_HEAD.TEST_CLASS_JSON", "datasets/voc20.json",
            "TEST.SLIDING_WINDOW", "True",
            "MODEL.SEM_SEG_HEAD.POOLING_SIZES", "[1,1]",
        ],
        nargs=argparse.REMAINDER,
    )
    return parser


@functools.lru_cache(maxsize=1)
def _load_cfg():
    """Parse the command line once and return the frozen config.

    Cached so requests do not re-parse ``sys.argv`` and re-merge the YAML
    config every time ``predict`` runs.
    """
    return setup_cfg(get_parser().parse_args())


def save_masks(preds, text):
    """Write one binary mask PNG per class name to ``masks/mask_<name>.png``.

    Args:
        preds: Mapping whose ``'sem_seg'`` entry holds per-class scores; it
            must support ``.argmax(dim=0).cpu().numpy()`` (presumably a torch
            tensor laid out as C x H x W).
        text: Iterable of class names; the position of a name is the class
            index it is matched against.
    """
    labels = preds['sem_seg'].argmax(dim=0).cpu().numpy()  # C H W -> H W

    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs("masks", exist_ok=True)

    for i, t in enumerate(text):
        mask_path = f"masks/mask_{t}.png"
        # Explicit 8-bit image: 255 where the pixel belongs to class i, else 0.
        mask = (labels == i).astype(np.uint8) * 255
        if not cv2.imwrite(mask_path, mask):
            warnings.warn(f"Could not write mask to {mask_path!r}")


def predict(image, text):
    """Segment ``image`` into the classes listed in ``text``.

    Args:
        image: Input image array as delivered by the Gradio image component.
        text: Comma-separated class names.

    Returns:
        The rendered visualization as an ``(H, W, 3)`` uint8 array with the
        channel order of the rendered figure reversed.
    """
    # Each request gets its own copy so the cached config is never shared
    # with the predictor.
    cfg = _load_cfg().clone()
    demo = VisualizationDemo(cfg, text=text)
    predictions, visualized_output = demo.run_on_image(image)
    save_masks(predictions, text.split(','))

    canvas = fc(visualized_output.fig)
    canvas.draw()
    # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib removed.
    rgba = np.asarray(canvas.buffer_rgba())
    # Drop alpha and reverse the colour channels (index 2, 1, 0).
    return np.ascontiguousarray(rgba[..., 2::-1])


if __name__ == "__main__":
    # Build the config up front so a bad config file fails at startup rather
    # than on the first request; predict() reuses the cached result.
    _load_cfg()

    iface = gr.Interface(
        fn=predict,
        inputs=[gr.Image(), gr.Textbox(placeholder="Classes to segment")],
        outputs="image",
    )
    iface.launch()