import os

# gradio for the visual demo
import gradio as gr

# install the remaining dependencies at startup (transformers from source for access to the pretrained networks)
os.system("pip install git+https://github.com/huggingface/transformers.git")
os.system("pip install datasets")
os.system("pip install scipy")
os.system("pip install torch")

from transformers import AutoImageProcessor, AutoModelForImageClassification, DPTForDepthEstimation, Mask2FormerForUniversalSegmentation
import torch
import numpy as np
from PIL import Image
import requests
from collections import defaultdict
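
# color palette used to colorize segmentation outputs: one RGB triplet per class / segment id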
palette = np.asarray([
[0, 0, 0],
[120, 120, 120],
[180, 120, 120],
[6, 230, 230],
[80, 50, 50],
[4, 200, 3],
[120, 120, 80],
[140, 140, 140],
[204, 5, 255],
[230, 230, 230],
[4, 250, 7],
[224, 5, 255],
[235, 255, 7],
[150, 5, 61],
[120, 120, 70],
[8, 255, 51],
[255, 6, 82],
[143, 255, 140],
[204, 255, 4],
[255, 51, 7],
[204, 70, 3],
[0, 102, 200],
[61, 230, 250],
[255, 6, 51],
[11, 102, 255],
[255, 7, 71],
[255, 9, 224],
[9, 7, 230],
[220, 220, 220],
[255, 9, 92],
[112, 9, 255],
[8, 255, 214],
[7, 255, 224],
[255, 184, 6],
[10, 255, 71],
[255, 41, 10],
[7, 255, 255],
[224, 255, 8],
[102, 8, 255],
[255, 61, 6],
[255, 194, 7],
[255, 122, 8],
[0, 255, 20],
[255, 8, 41],
[255, 5, 153],
[6, 51, 255],
[235, 12, 255],
[160, 150, 20],
[0, 163, 255],
[140, 140, 140],
[250, 10, 15],
[20, 255, 0],
[31, 255, 0],
[255, 31, 0],
[255, 224, 0],
[153, 255, 0],
[0, 0, 255],
[255, 71, 0],
[0, 235, 255],
[0, 173, 255],
[31, 0, 255],
[11, 200, 200],
[255, 82, 0],
[0, 255, 245],
[0, 61, 255],
[0, 255, 112],
[0, 255, 133],
[255, 0, 0],
[255, 163, 0],
[255, 102, 0],
[194, 255, 0],
[0, 143, 255],
[51, 255, 0],
[0, 82, 255],
[0, 255, 41],
[0, 255, 173],
[10, 0, 255],
[173, 255, 0],
[0, 255, 153],
[255, 92, 0],
[255, 0, 255],
[255, 0, 245],
[255, 0, 102],
[255, 173, 0],
[255, 0, 20],
[255, 184, 184],
[0, 31, 255],
[0, 255, 61],
[0, 71, 255],
[255, 0, 204],
[0, 255, 194],
[0, 255, 82],
[0, 10, 255],
[0, 112, 255],
[51, 0, 255],
[0, 194, 255],
[0, 122, 255],
[0, 255, 163],
[255, 153, 0],
[0, 255, 10],
[255, 112, 0],
[143, 255, 0],
[82, 0, 255],
[163, 255, 0],
[255, 235, 0],
[8, 184, 170],
[133, 0, 255],
[0, 255, 92],
[184, 0, 255],
[255, 0, 31],
[0, 184, 255],
[0, 214, 255],
[255, 0, 112],
[92, 255, 0],
[0, 224, 255],
[112, 224, 255],
[70, 184, 160],
[163, 0, 255],
[153, 0, 255],
[71, 255, 0],
[255, 0, 163],
[255, 204, 0],
[255, 0, 143],
[0, 255, 235],
[133, 255, 0],
[255, 0, 235],
[245, 0, 255],
[255, 0, 122],
[255, 245, 0],
[10, 190, 212],
[214, 255, 0],
[0, 204, 255],
[20, 0, 255],
[255, 255, 0],
[0, 153, 255],
[0, 41, 255],
[0, 255, 204],
[41, 0, 255],
[41, 255, 0],
[173, 0, 255],
[0, 245, 255],
[71, 0, 255],
[122, 0, 255],
[0, 255, 184],
[0, 92, 255],
[184, 255, 0],
[0, 133, 255],
[255, 214, 0],
[25, 194, 194],
[102, 255, 0],
[92, 0, 255],
])
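
# monocular depth estimation: DPT head on a DINOv2-small backbone, trained on NYU Depth v2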
depth_image_processor = AutoImageProcessor.from_pretrained("facebook/dpt-dinov2-small-nyu")
depth_model = DPTForDepthEstimation.from_pretrained("facebook/dpt-dinov2-small-nyu")
def compute_depth(img):
    # prepare image for the model
    inputs = depth_image_processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = depth_model(**inputs)
        predicted_depth = outputs.predicted_depth
    # interpolate to original size
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=img.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    # visualize the prediction
    output = prediction.squeeze().cpu().numpy()
    formatted = (output * 255 / np.max(output)).astype("uint8")
    depth = Image.fromarray(formatted)
    return [depth, "depth"]
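
# image classification: DINOv2-small backbone with a single linear head trained on ImageNet-1k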
clas_processor = AutoImageProcessor.from_pretrained('facebook/dinov2-small-imagenet1k-1-layer')
clas_model = AutoModelForImageClassification.from_pretrained('facebook/dinov2-small-imagenet1k-1-layer')
def compute_clas(img):
    # prepare image for the model and classify it
    inputs = clas_processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = clas_model(**inputs)
    # return the original image together with the most likely ImageNet-1k label
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    return [img, clas_model.config.id2label[predicted_class_idx]]
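
# universal segmentation: Mask2Former with a Swin-base backbone trained on COCO panoptic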
m2f_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
m2f_model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-base-coco-panoptic")
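
# turn a semantic segmentation map into a colorized image plus a legend of (label name, color) entries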
def seg2sem(seg):
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
    handles = []
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
        if (seg == label).count_nonzero() > 0:
            handles.append(m2f_model.config.id2label[label])
            handles.append(color)
    color_seg = color_seg.astype(np.uint8)
    image = Image.fromarray(color_seg)
    return [image, handles]
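
# turn a panoptic segmentation map into a colorized image (one color per segment id) plus a per-instance legend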
def seg2pano(seg, segments_info):
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    color_seg = color_seg.astype(np.uint8)
    image = Image.fromarray(color_seg)
    # build the legend: one "label-instance_index" entry per segment, with its color
    instances_counter = defaultdict(int)
    handles = []
    for segment in segments_info:
        segment_id = segment['id']
        segment_label_id = segment['label_id']
        segment_label = m2f_model.config.id2label[segment_label_id]
        label = f"{segment_label}-{instances_counter[segment_label_id]}"
        instances_counter[segment_label_id] += 1
        color = palette[segment_id]
        handles.append(label)
        handles.append(color)
    return [image, handles]
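
# run Mask2Former and post-process its output as a semantic segmentation map at the original image size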
def compute_m2f_sem_seg(img):
    inputs = m2f_processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = m2f_model(**inputs)
    seg = m2f_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[img.size[::-1]]
    )[0]
    return seg2sem(seg)
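
# run Mask2Former and post-process its output as a panoptic segmentation map at the original image size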
def compute_m2f_pano_seg(img):
    inputs = m2f_processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = m2f_model(**inputs)
    seg = m2f_processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[img.size[::-1]]
    )[0]
    return seg2pano(seg["segmentation"], seg["segments_info"])
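
# application names shown in the radio selector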
labels = ["Dinov2 - Depth", "Dinov2 - Classification", "M2F - Semantic Segmentation", "M2F - Panoptic Segmentation"]
# main function
def detect(img, application):
    if application == labels[0]:
        return compute_depth(img)
    elif application == labels[1]:
        return compute_clas(img)
    elif application == labels[2]:
        return compute_m2f_sem_seg(img)
    elif application == labels[3]:
        return compute_m2f_pano_seg(img)
    # no application selected: return the input unchanged with an empty label
    return [img, ""]
# visual gradio interface
iface = gr.Interface(
    fn=detect,
    inputs=[gr.Image(type="pil"), gr.Radio(labels, label="Application")],
    outputs=[gr.Image(type="pil"), gr.Textbox()],
)
iface.launch(debug=True)