# Scraped page header (not part of the program):
# sfmig — "added a different color palette" — commit 6c333c9 — 4.33 kB
"""
Using as reference:
- https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
- https://huggingface.co/spaces/chansung/segformer-tf-transformers/blob/main/app.py
- https://huggingface.co/facebook/detr-resnet-50-panoptic
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DETR/DETR_panoptic_segmentation_minimal_example_(with_DetrFeatureExtractor).ipynb
https://arxiv.org/abs/2005.12872
Additions
- add shown labels as strings
- show only animal masks (ask an nlp model?)
For next time
- for diff 'confidence' the high conf masks should change....
- colors are not great and should be constant per class? add text?
- Im getting core dumped (segmentation fault) when loading hugging face model.. :()
https://github.com/huggingface/transformers/issues/16939
- cap slider to 95?
"""
from transformers import DetrFeatureExtractor, DetrForSegmentation
from PIL import Image
import gradio as gr
import numpy as np
import torch
import torchvision
import itertools
import seaborn as sns
def predict_animal_mask(im,
                        gr_slider_confidence):
    """Overlay the DETR panoptic masks whose confidence exceeds a threshold.

    The input is converted to an RGB PIL image, resized to 200x200 and run
    through the module-level ``feature_extractor`` and ``model``. A query is
    kept when its most likely class (excluding the trailing "no-object"
    class) has probability above the threshold. Each pixel is assigned to
    the kept query with the highest mask logit. Each resulting segment is
    painted with a colour from the seaborn palette and blended 50/50 with
    the image.

    Args:
        im: image as a numpy array, as delivered by the gradio Image input
            (presumably H x W x 3 uint8 - TODO confirm for other modes).
        gr_slider_confidence: confidence threshold in percent (0-100).

    Returns:
        uint8 numpy array (200, 200, 3) with the masks overlaid. If no mask
        passes the threshold, the resized image is returned without overlay.
    """
    image = Image.fromarray(im).convert('RGB')
    image = image.resize((200, 200))  # could upsample the output instead?
    image_rgb = np.array(image)  # (H, W, 3) uint8

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        # encoding is a dict with pixel_values and pixel_mask (pt = PyTorch)
        encoding = feature_extractor(images=image, return_tensors="pt")
        outputs = model(**encoding)

    # Probability of each query's best class, excluding "no-object" (last).
    prob_per_query = outputs.logits.softmax(-1)[..., :-1].max(-1)[0]
    keep = prob_per_query > gr_slider_confidence / 100.0
    # Boolean-indexing (batch, queries, h, w) with a (batch, queries) mask
    # gives (n_kept, h, w). Unlike .squeeze(), this stays 3-D when n_kept == 1.
    kept_masks = outputs.pred_masks[keep]
    if kept_masks.shape[0] == 0:
        # Nothing passes the threshold; argmax over an empty dim would raise.
        return image_rgb

    # For every pixel, pick the kept query with the highest mask logit.
    label_per_pixel = kept_masks.argmax(dim=0).numpy()
    if label_per_pixel.shape != image_rgb.shape[:2]:
        # The mask resolution depends on the feature extractor's own resize.
        # Bring the label map to the image size. Nearest-neighbour keeps the
        # labels discrete, and n_kept <= number of queries fits in uint8 here.
        label_img = Image.fromarray(label_per_pixel.astype(np.uint8))
        label_per_pixel = np.asarray(label_img.resize(image.size, Image.NEAREST))

    color_mask = np.zeros(image_rgb.shape, dtype=float)
    palette = itertools.cycle(sns.color_palette())
    for lbl in np.unique(label_per_pixel):
        # seaborn colours are floats in [0, 1]; scale to the 0-255 image range.
        color_mask[label_per_pixel == lbl] = np.asarray(next(palette)) * 255

    # Show image + mask, blended 50/50.
    pred_img = image_rgb * 0.5 + color_mask * 0.5
    return pred_img.astype(np.uint8)
#######################################
# Pretrained DETR (ResNet-50 backbone) panoptic checkpoint from the Hugging
# Face Hub. predict_animal_mask reads these two names as module globals.
_checkpoint = 'facebook/detr-resnet-50-panoptic'
feature_extractor = DetrFeatureExtractor.from_pretrained(_checkpoint)
model = DetrForSegmentation.from_pretrained(_checkpoint)

# Gradio inputs: the image to segment, plus a confidence slider
# (range 0-100, step 5, initial value 85).
gr_image_input = gr.inputs.Image()
gr_slider_confidence = gr.inputs.Slider(
    0, 100, 5, 85,
    label='Set confidence threshold for masks',
)
# Gradio output: the image with the masks overlaid.
gr_image_output = gr.outputs.Image()

####################################################
# Assemble the user interface and launch it.
_app_description = (
    "A panoptic (semantic+instance) segmentation webapp using DETR "
    "(End-to-End Object Detection) model with ResNet-50 backbone"
)
demo = gr.Interface(
    predict_animal_mask,
    inputs=[gr_image_input, gr_slider_confidence],
    outputs=gr_image_output,
    title='Image segmentation with varying confidence',
    description=_app_description,
)
demo.launch()
####################################
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image = Image.open(requests.get(url, stream=True).raw)
# inputs = feature_extractor(images=image, return_tensors="pt")
# outputs = model(**inputs)
# logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)