|
import random

import gradio as gr
import numpy as np
import torch
from torch import nn
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
|
|
|
|
|
# Path handed to `from_pretrained` to load the segmentation checkpoint.
MODEL_PATH = "./best_model_test/"

# Inference runs on CPU.
device = torch.device("cpu")

# Image -> tensor pipeline applied to every uploaded picture before inference.
preprocessor = transforms.Compose([
    # The transform class is spelled `Resize`; the lowercase
    # `transforms.resize` is not a torchvision transform and made the module
    # fail on load.  With an int size the shorter image edge is matched to
    # 128 px and the aspect ratio is preserved.
    transforms.Resize(128),
    transforms.ToTensor(),
])

model = SegformerForSemanticSegmentation.from_pretrained(MODEL_PATH)
model.eval()  # switch to evaluation mode for inference
|
|
|
|
|
def upscale_logits(logit_outputs, size):
    """Resample logits to ``size`` using bilinear interpolation.

    Intended to scale the logits up by 4x along width and height so they
    recover the original dimensions of the model input.

    Args:
        logit_outputs: logits tensor to resample.
        size: target spatial size, forwarded unchanged to
            ``nn.functional.interpolate``.

    Returns:
        The resampled logits tensor.
    """
    resize_options = {
        "size": size,
        "mode": "bilinear",
        "align_corners": False,
    }
    return nn.functional.interpolate(logit_outputs, **resize_options)
|
|
|
|
|
def visualize_instance_seg_mask(mask):
    """Paint every class label found in ``mask`` with a random RGB color.

    Args:
        mask: 2-D array holding one class label per pixel.

    Returns:
        ``float64`` array of shape ``mask.shape + (3,)`` whose channel values
        are scaled to ``[0, 1]``.  Colors come from the global ``random``
        state, so they change between calls unless that state is seeded.
    """
    mask = np.asarray(mask)
    labels, inverse = np.unique(mask, return_inverse=True)

    # One color per distinct label, drawn in label order.
    # NOTE(review): the first channel is drawn from {0, 1} rather than
    # 0..255, so colors carry almost no red.  Kept as-is because the intent
    # is unclear -- confirm whether `randint(0, 255)` was meant.
    palette = np.array(
        [
            (
                random.randint(0, 1),
                random.randint(0, 255),
                random.randint(0, 255),
            )
            for _ in labels
        ],
        dtype=np.float64,
    ).reshape(-1, 3)  # keeps shape (0, 3) when the mask is empty

    # `inverse` gives, for each pixel, the index of its label in `labels`, so
    # a single fancy-indexing lookup replaces the per-pixel Python loop.  The
    # reshape is needed because the shape of `inverse` varies across NumPy
    # versions.
    return palette[inverse.reshape(mask.shape)] / 255
|
|
|
|
|
def query_image(img):
    """Segment ``img`` and return a color-coded class mask.

    Args:
        img: image accepted by ``preprocessor`` (the Gradio input is
            configured to deliver a PIL image).

    Returns:
        The RGB array built by ``visualize_instance_seg_mask`` from the
        predicted labels, at the resolution of the preprocessed model input.
    """
    # Add a leading batch dimension for the model.  The former
    # `preprocessor(images=..., return_tensors="pt")` call was removed: the
    # torchvision pipeline takes the image positionally and rejects those
    # keyword arguments, and its result was immediately overwritten anyway.
    inputs = preprocessor(img).unsqueeze(0)

    with torch.no_grad():
        preds = model(inputs)["logits"]
        # Upscale to the model-input (H, W).  Passing the logits' own height
        # as a single int left the resolution unchanged and forced a square
        # output, distorting non-square images.
        target_size = tuple(inputs.shape[-2:])
        preds_upscale = upscale_logits(preds, target_size)
        # Most likely class per pixel, taken over the class dimension.
        predict_label = torch.argmax(preds_upscale, dim=1)

    result = predict_label[0].cpu().numpy()
    return visualize_instance_seg_mask(result)
|
|
|
|
|
# Gradio UI: a single PIL image goes in, the image returned by `query_image`
# comes out.
demo = gr.Interface(
    fn=query_image,
    title="Skyguard: segmentador de glaciares de roca 🛰️ +️ 🛡️ ️",
    description="Modelo de segmentación de imágenes para detectar glaciares de roca.<br> Se entrenó un modelo [nvidia/SegFormer](https://huggingface.co/nvidia/mit-b0) con _fine-tuning_ en el [rock-glacier-dataset](https://huggingface.co/datasets/alkzar90/rock-glacier-dataset)",
    inputs=[gr.Image(type="pil")],
    outputs="image",
)

# Start serving the interface.
demo.launch()
|
|