# Spaces:
# Runtime error
# Runtime error
import functools

import gradio as gr
import numpy as np
import torch
from torch import nn
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation

from palette import ade_palette
# SegFormer-b0 fine-tuned for semantic segmentation on ADE20K. The previously
# used "nvidia/mit-b0" is only the pre-trained encoder, so its segmentation
# decode head was randomly initialised and the output was meaningless.
# Presumably its class ids line up with `ade_palette()` — confirm against palette.py.
_CHECKPOINT = "nvidia/segformer-b0-finetuned-ade-512-512"

# Size the input is resized to before inference, as PIL (width, height).
_RESIZE_TO = (200, 200)


@functools.lru_cache(maxsize=1)
def _load_segformer(checkpoint):
    """Load (once) and return ``(feature_extractor, model)`` for *checkpoint*.

    Loading weights is slow, so it is cached instead of repeated per request.
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)
    model.eval()  # inference only: make sure dropout etc. are disabled
    return feature_extractor, model


def colorize_segmentation(label_map, palette):
    """Map a 2-D array of class ids to an RGB image using *palette*.

    Args:
        label_map: integer array of shape (height, width) holding class ids.
        palette: array-like of RGB triples; entry ``i`` is the colour of class ``i``.

    Returns:
        uint8 array of shape (height, width, 3). Pixels whose class id has no
        palette entry are left black.
    """
    label_map = np.asarray(label_map)
    # reshape keeps an empty palette well-formed as shape (0, 3)
    palette = np.asarray(palette, dtype=np.uint8).reshape(-1, 3)
    color_seg = np.zeros((*label_map.shape, 3), dtype=np.uint8)
    known = (label_map >= 0) & (label_map < len(palette))
    color_seg[known] = palette[label_map[known]]
    return color_seg


def seg(image):
    """Segment *image* with SegFormer and return a colour overlay.

    Args:
        image: PIL image (the Gradio input uses ``type='pil'``), or None when
            nothing was uploaded.

    Returns:
        uint8 numpy array of shape (height, width, 3) for the resized image:
        the image blended 50/50 with its per-pixel class colours, or None if
        *image* is None.
    """
    if image is None:
        return None

    # PIL's resize returns a new image (the old code discarded it).
    # convert("RGB") guarantees three channels so the blend below lines up
    # for greyscale or RGBA uploads as well.
    image = image.convert("RGB").resize(_RESIZE_TO)

    feature_extractor, model = _load_segformer(_CHECKPOINT)
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # The model outputs logits of shape (batch_size, num_labels, height/4, width/4).
    # Upsample them to the image size; PIL's size is (width, height) while
    # interpolate expects (height, width).
    logits = nn.functional.interpolate(
        outputs.logits.cpu(),
        size=image.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    # Argmax over the class dimension gives one class id per pixel.
    label_map = logits.argmax(dim=1)[0].numpy()
    color_seg = colorize_segmentation(label_map, ade_palette())

    # The image and Gradio's output are RGB, so the palette colours are used
    # as-is (no BGR flip, which previously swapped red and blue).
    blended = np.asarray(image) * 0.5 + color_seg * 0.5
    return blended.astype(np.uint8)
# Gradio UI: upload an image, get back the segmentation overlay from `seg`.
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4;
# gr.Image serves as both the input and the output component.
iface = gr.Interface(
    fn=seg,
    inputs=gr.Image(type="pil"),  # `seg` receives a PIL image
    outputs=gr.Image(),
)

if __name__ == "__main__":
    iface.launch()