# Source: Hugging Face Space file by shyamgupta196
# Commit 3dc8b10: "segformer extracter" (1.7 kB)
import functools

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

from palette import ade_palette
# NOTE: the original code loaded "nvidia/mit-b5", which is the ImageNet
# pre-trained SegFormer *encoder only*. Loading it into
# SegformerForSemanticSegmentation leaves the decode head randomly
# initialised, so its predictions are meaningless. This checkpoint is the b5
# model fine-tuned on ADE20K, matching the ade_palette used below.
DEFAULT_CHECKPOINT = "nvidia/segformer-b5-finetuned-ade-640-640"


@functools.lru_cache(maxsize=2)
def _load_segmenter(checkpoint):
    """Load the feature extractor and model for *checkpoint*, once per id.

    Weight loading is expensive, so it happens on the first request for a
    checkpoint instead of on every call to ``seg``.
    """
    feature_extractor = SegformerFeatureExtractor.from_pretrained(checkpoint)
    model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)
    return feature_extractor, model


def _colorize(label_map, palette):
    """Map an (H, W) integer label map to an (H, W, 3) uint8 colour image.

    Labels that have no entry in *palette* are painted black (0, 0, 0).
    """
    # Build a lookup table large enough for every label that occurs, so an
    # out-of-palette label falls on a zero row instead of raising IndexError.
    size = max(int(label_map.max()) + 1, len(palette))
    lut = np.zeros((size, 3), dtype=np.uint8)
    lut[: len(palette)] = palette
    return lut[label_map]


def seg(image, checkpoint=DEFAULT_CHECKPOINT):
    """Overlay a semantic-segmentation colour map on *image*.

    Args:
        image: PIL image to segment, or ``None`` when nothing was submitted.
            It is converted to RGB first so that RGBA, greyscale and palette
            images blend correctly with the 3-channel colour map.
        checkpoint: Hugging Face id of a SegformerForSemanticSegmentation
            checkpoint. The default is the ADE20K fine-tuned b5 model, whose
            label ids are intended to index ``ade_palette``.

    Returns:
        A PIL image that blends the input 50/50 with the per-pixel class
        colours, or ``None`` if *image* is ``None``.
    """
    if image is None:
        return None
    image = image.convert("RGB")
    feature_extractor, model = _load_segmenter(checkpoint)

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # The model outputs logits of shape (batch_size, num_labels, height/4,
    # width/4). Upsample them bilinearly to the original image size, then
    # take the argmax over the class dimension.
    logits = nn.functional.interpolate(
        outputs.logits.cpu(),
        size=image.size[::-1],  # PIL size is (width, height); torch wants (height, width)
        mode="bilinear",
        align_corners=False,
    )
    label_map = logits.argmax(dim=1)[0].numpy()

    color_seg = _colorize(label_map, np.array(ade_palette()))
    # The input was converted to RGB above. ade_palette is assumed to be RGB
    # as well (as in the Hugging Face SegFormer example), so no BGR flip is
    # applied; flipping here would paint colours that do not match the palette.
    # NOTE(review): confirm palette.py stores RGB triples.
    blended = (np.asarray(image) * 0.5 + color_seg * 0.5).astype(np.uint8)
    return Image.fromarray(blended)
# Gradio UI: one PIL image in, the segmentation overlay (a PIL image) out.
# NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and were
# removed in 4.0. If the pinned Gradio version is >= 3.0, switch both
# components to gr.Image(type="pil").
iface = gr.Interface(
    fn=seg,
    inputs=gr.inputs.Image(type="pil"),
    outputs=gr.outputs.Image(type="pil"),
)

# Launch only when run as a script, so importing this module (e.g. from tests
# or Gradio's reload mode) does not start a web server as a side effect.
if __name__ == "__main__":
    iface.launch()