import numpy as np
import gradio as gr
from PIL import Image
import torch
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation

model_checkpoint = "apple/deeplabv3-mobilevit-small"

feature_extractor = MobileViTFeatureExtractor.from_pretrained(model_checkpoint)
model = MobileViTForSemanticSegmentation.from_pretrained(model_checkpoint).eval()

# One RGB color per Pascal VOC class, used to draw the segmentation map.
palette = np.array(
    [
        [  0,   0,   0],
        [192,   0,   0],
        [  0, 192,   0],
        [192, 192,   0],
        [  0,   0, 192],
        [192,   0, 192],
        [  0, 192, 192],
        [192, 192, 192],
        [128,   0,   0],
        [255,   0,   0],
        [128, 192,   0],
        [255, 192,   0],
        [128,   0, 192],
        [255,   0, 192],
        [128, 192, 192],
        [255, 192, 192],
        [  0, 128,   0],
        [192, 128,   0],
        [  0, 255,   0],
        [192, 255,   0],
        [  0, 128, 192],
    ],
    dtype=np.uint8)

labels = [
    "background",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
]

# Draw the labels. Light colors use black text, dark colors use white text.
inverted = [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20]

labels_colored = []
for i in range(len(labels)):
    r, g, b = palette[i]
    label = labels[i]
    color = "white" if i in inverted else "black"
    # Render each class name as an HTML span filled with its palette color.
    text = "<span style='background-color: rgb(%d, %d, %d); color: %s'>%s</span>" % (r, g, b, color, label)
    labels_colored.append(text)
labels_text = ", ".join(labels_colored)

title = "Semantic Segmentation with MobileViT and DeepLabV3"

description = """
The input image is resized and center cropped to 512×512 pixels. The segmentation output is 32×32 pixels.
This model has been trained on Pascal VOC. The classes are: """ + labels_text

article = """

Sources:

📜 MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer

🏋️ Original pretrained weights from this GitHub repo

🏙 Example images from this dataset

""" examples = [ ["cat-3.jpg"], ["construction-site.jpg"], ["dog-cat.jpg"], ["football-match.jpg"], ] def predict(image): with torch.no_grad(): inputs = feature_extractor(image, return_tensors="pt") outputs = model(**inputs) # Get preprocessed image. The pixel values don't need to be unnormalized # for this particular model. resized = (inputs["pixel_values"].numpy().squeeze().transpose(1, 2, 0)[..., ::-1] * 255).astype(np.uint8) # Class predictions for each pixel. classes = outputs.logits.argmax(1).squeeze().numpy().astype(np.uint8) # Super slow method but it works... should probably improve this. colored = np.zeros((classes.shape[0], classes.shape[1], 3), dtype=np.uint8) for y in range(classes.shape[0]): for x in range(classes.shape[1]): colored[y, x] = palette[classes[y, x]] # Resize predictions to input size (not original size). colored = Image.fromarray(colored) colored = colored.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST) # Keep everything that is not background. mask = (classes != 0) * 255 mask = Image.fromarray(mask.astype(np.uint8)).convert("RGB") mask = mask.resize((resized.shape[1], resized.shape[0]), resample=Image.Resampling.NEAREST) # Blend with the input image. resized = Image.fromarray(resized) highlighted = Image.blend(resized, mask, 0.4) #colored = colored.resize((256, 256), resample=Image.Resampling.BICUBIC) #highlighted = highlighted.resize((256, 256), resample=Image.Resampling.BICUBIC) return colored, highlighted gr.Interface( fn=predict, inputs=gr.inputs.Image(label="Upload image"), outputs=[gr.outputs.Image(label="Classes"), gr.outputs.Image(label="Overlay")], title=title, description=description, article=article, examples=examples, ).launch()