"""Gradio demo: semantic segmentation with a SegFormer fine-tuned on segments/sidewalk.

Loads a Hugging Face `image-segmentation` pipeline once at startup, overlays the
predicted class masks on the uploaded image with a per-class color palette, and
serves the result through a Gradio Blocks interface.
"""

import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from PIL import Image
from transformers import pipeline

# Segmentation pipeline is created once at import time so every request reuses it.
model_pipeline = pipeline(
    "image-segmentation",
    model="sergiopaniego/segformer-b0-segments-sidewalk-finetuned",
)

# Class index -> human-readable label for the segments/sidewalk label set.
id2label = {
    0: 'unlabeled',
    1: 'flat-road',
    2: 'flat-sidewalk',
    3: 'flat-crosswalk',
    4: 'flat-cyclinglane',
    5: 'flat-parkingdriveway',
    6: 'flat-railtrack',
    7: 'flat-curb',
    8: 'human-person',
    9: 'human-rider',
    10: 'vehicle-car',
    11: 'vehicle-truck',
    12: 'vehicle-bus',
    13: 'vehicle-tramtrain',
    14: 'vehicle-motorcycle',
    15: 'vehicle-bicycle',
    16: 'vehicle-caravan',
    17: 'vehicle-cartrailer',
    18: 'construction-building',
    19: 'construction-door',
    20: 'construction-wall',
    21: 'construction-fenceguardrail',
    22: 'construction-bridge',
    23: 'construction-tunnel',
    24: 'construction-stairs',
    25: 'object-pole',
    26: 'object-trafficsign',
    27: 'object-trafficlight',
    28: 'nature-vegetation',
    29: 'nature-terrain',
    30: 'sky',
    31: 'void-ground',
    32: 'void-dynamic',
    33: 'void-static',
    34: 'void-unclear',
}

# RGB color per class index, aligned with `id2label`.
# FIX: construction-wall previously reused flat-sidewalk's [255, 255, 0],
# making the two classes indistinguishable in the overlay; it now gets a
# distinct grey not used by any other class.
sidewalk_palette = [
    [0, 0, 0],       # unlabeled
    [216, 82, 24],   # flat-road
    [255, 255, 0],   # flat-sidewalk
    [125, 46, 141],  # flat-crosswalk
    [118, 171, 47],  # flat-cyclinglane
    [161, 19, 46],   # flat-parkingdriveway
    [255, 0, 0],     # flat-railtrack
    [0, 128, 128],   # flat-curb
    [190, 190, 0],   # human-person
    [0, 255, 0],     # human-rider
    [0, 0, 255],     # vehicle-car
    [170, 0, 255],   # vehicle-truck
    [84, 84, 0],     # vehicle-bus
    [84, 170, 0],    # vehicle-tramtrain
    [84, 255, 0],    # vehicle-motorcycle
    [170, 84, 0],    # vehicle-bicycle
    [170, 170, 0],   # vehicle-caravan
    [170, 255, 0],   # vehicle-cartrailer
    [255, 84, 0],    # construction-building
    [255, 170, 0],   # construction-door
    [128, 128, 128], # construction-wall (was a duplicate of flat-sidewalk)
    [33, 138, 200],  # construction-fenceguardrail
    [0, 170, 127],   # construction-bridge
    [0, 255, 127],   # construction-tunnel
    [84, 0, 127],    # construction-stairs
    [84, 84, 127],   # object-pole
    [84, 170, 127],  # object-trafficsign
    [84, 255, 127],  # object-trafficlight
    [170, 0, 127],   # nature-vegetation
    [170, 84, 127],  # nature-terrain
    [170, 170, 127], # sky
    [170, 255, 127], # void-ground
    [255, 0, 127],   # void-dynamic
    [255, 84, 127],  # void-static
    [255, 170, 127], # void-unclear
]

# Reverse lookup built once so resolving a result's label is O(1) per mask
# (the previous list(...).index(label) call scanned all labels per result).
_label2index = {label: index for index, label in id2label.items()}


def get_output_figure(pil_img, results):
    """Return a matplotlib figure of `pil_img` with segmentation masks overlaid.

    Args:
        pil_img: Input PIL image (any mode; converted to RGB internally so
            grayscale or RGBA uploads do not break the 3-channel overlay).
        results: List of dicts from the image-segmentation pipeline, each with
            a 'label' string and a 'mask' PIL image (nonzero pixels = class).

    Returns:
        The matplotlib Figure containing the original image with a
        half-transparent color-coded segmentation map on top. The caller is
        responsible for closing the figure (see `detect`).
    """
    # Convert to RGB so the overlay buffer is always (H, W, 3) regardless of
    # the uploaded image's mode.
    image_array = np.array(pil_img.convert("RGB"))
    segmentation_map = np.zeros_like(image_array)
    for result in results:
        mask = np.array(result['mask'])
        color = sidewalk_palette[_label2index[result['label']]]
        # Paint all three channels at once where the mask is set.
        segmentation_map[mask > 0] = color

    fig = plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    plt.imshow(segmentation_map, alpha=0.5)
    plt.axis('off')
    return fig


@spaces.GPU
def detect(image):
    """Run segmentation on `image` and return the overlay as a PIL image.

    Args:
        image: Input PIL image supplied by the Gradio component.

    Returns:
        A PIL image rendered from the matplotlib overlay figure.
    """
    results = model_pipeline(image)
    print(results)

    output_figure = get_output_figure(image, results)

    buf = io.BytesIO()
    output_figure.savefig(buf, bbox_inches='tight')
    # Close the figure explicitly: matplotlib keeps figures alive otherwise,
    # which leaks memory across requests in a long-running Space.
    plt.close(output_figure)
    buf.seek(0)
    output_pil_img = Image.open(buf)
    return output_pil_img


with gr.Blocks() as demo:
    gr.Markdown("# Semantic segmentation with SegFormer fine-tuned on segments/sidewalk")
    gr.Markdown(
        """
        This application uses a fine-tuned SegFormer for semantic segmentation over an input image.
        This version was trained using the segments/sidewalk dataset.
        You can load an image and see the predicted segmentation.
        """
    )

    gr.Interface(
        fn=detect,
        inputs=gr.Image(label="Input image", type="pil"),
        outputs=[
            gr.Image(label="Output prediction", type="pil")
        ],
    )

demo.launch(show_error=True)