salaz055's picture
Update app.py
afa3b98
raw
history blame contribute delete
941 Bytes
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import gradio as gr
import torch
import numpy as np
# Hugging Face Hub repository ids for the preprocessor and the fine-tuned
# segmentation weights.
EXTRACTOR_REPO_ID = "salaz055/my_extractor_segmentation_model"
MODEL_REPO_ID = "salaz055/my-segmentation-model"

# Loaded once at import time; `classify` reuses these module-level objects on
# every call instead of reloading them.
extractor = AutoFeatureExtractor.from_pretrained(EXTRACTOR_REPO_ID)
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_REPO_ID)
# RGB colors for the first five class ids (the original hard-coded palette).
_BASE_PALETTE = np.array(
    [[128, 0, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128], [0, 0, 0]],
    dtype=np.uint8,
)


def _class_palette(num_classes):
    """Return a uint8 RGB palette with at least ``num_classes`` rows.

    The first rows are ``_BASE_PALETTE`` so existing class colors are kept.
    Any classes beyond those get extra colors from a fixed-seed generator,
    so a given class id maps to the same color on every call.
    """
    extra = num_classes - len(_BASE_PALETTE)
    if extra <= 0:
        return _BASE_PALETTE
    rng = np.random.default_rng(0)
    extra_colors = rng.integers(0, 256, size=(extra, 3)).astype(np.uint8)
    return np.concatenate([_BASE_PALETTE, extra_colors])


def classify(im):
    """Segment *im* and return a color-coded class map.

    Args:
        im: Input image. The Gradio interface in this file passes a PIL image
            (``gr.Image(type='pil')``). ``None`` (e.g. a submit with no image
            uploaded) is accepted and returns ``None``.

    Returns:
        A ``uint8`` array of shape (height, width, 3) with the same height and
        width as *im*. Each pixel holds the color of its predicted class.
    """
    if im is None:
        return None

    inputs = extractor(images=im, return_tensors="pt")
    # Inference only: don't build an autograd graph that would just be discarded.
    with torch.no_grad():
        logits = model(**inputs).logits  # (batch, num_classes, h, w)

    # SegFormer emits logits at a reduced resolution of the *preprocessed*
    # image, not the uploaded one. Upsample them to the input size before the
    # per-pixel argmax so the mask lines up with the image. Memory for this
    # step grows with num_classes * height * width.
    height, width = np.asarray(im).shape[:2]
    logits = torch.nn.functional.interpolate(
        logits, size=(height, width), mode="bilinear", align_corners=False
    )
    classes = logits[0].argmax(dim=0).cpu().numpy()

    # Size the palette from the model's class count so that no predicted
    # class id can index past the end of it.
    return _class_palette(logits.shape[1])[classes]
# Text shown in the header of the Gradio page.
_UI_TEXT = {
    "title": "Image Segmentation",
    "description": "Use for semantic image segmentation. Finetuned on sidewalk-semantic",
}

# One image in (delivered to `classify` as a PIL image), one image out.
interface = gr.Interface(
    fn=classify,
    inputs=gr.Image(type="pil"),
    outputs="image",
    **_UI_TEXT,
)
# Start the web app.
interface.launch()