# SuperPoint / app.py
# Hugging Face Space source. Last commit: "Update app.py" (bb8a0ae, verified)
# by merve (HF staff).
from transformers import AutoImageProcessor, SuperPointForKeypointDetection
import torch
import matplotlib.pyplot as plt
import uuid
import gradio as gr
# Hub checkpoint id shared by the image processor and the keypoint model.
CHECKPOINT = "magic-leap-community/superpoint"

# Loaded once at import time so every request reuses the same objects.
processor = AutoImageProcessor.from_pretrained(CHECKPOINT)
model = SuperPointForKeypointDetection.from_pretrained(CHECKPOINT)
def infer(image):
    """Detect SuperPoint keypoints on ``image`` and render them to a PNG file.

    Args:
        image: The input picture as a PIL image (the Gradio input component
            is declared with ``type="pil"``). ``None`` is rejected.

    Returns:
        Path of a newly written PNG in the current working directory showing
        the input image with the detected keypoints drawn as translucent cyan
        markers whose area is ``score * 100``.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image first.")

    inputs = processor(image, return_tensors="pt").to(model.device, model.dtype)

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        model_outputs = model(**inputs)

    # PIL reports size as (width, height); post-processing is handed
    # (height, width), matching the original call.
    width, height = image.size
    outputs = processor.post_process_keypoint_detection(
        model_outputs, [(height, width)]
    )

    # Move to CPU before converting so this also works when the model lives
    # on an accelerator.
    keypoints = outputs[0]["keypoints"].detach().cpu().numpy()
    scores = outputs[0]["scores"].detach().cpu().numpy()

    # Draw on an explicit figure rather than pyplot's implicit "current"
    # figure, and always close it so a failed render does not leak a figure.
    fig, ax = plt.subplots()
    try:
        ax.axis("off")
        ax.imshow(image)
        ax.scatter(
            keypoints[:, 0],
            keypoints[:, 1],
            s=scores * 100,
            c="cyan",
            alpha=0.4,
        )
        # Random file name so successive calls do not overwrite each other.
        path = "./" + uuid.uuid4().hex + ".png"
        fig.savefig(path)
    finally:
        plt.close(fig)

    return path
# Text shown at the top of the demo page.
title = "SuperPoint"
description = (
    "Try [SuperPoint](https://huggingface.co/docs/transformers/en/model_doc/superpoint) "
    "in this demo, foundation model for keypoint detection supported in 🤗 transformers. "
    "Simply upload an image or try the example. "
)

# Single-image in, single-image out; the input is delivered to `infer` as a
# PIL image and the returned file path is displayed as the result.
iface = gr.Interface(
    fn=infer,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(),
    title=title,
    description=description,
    examples=["./bee.jpg"],
)

# Start the web server.
iface.launch()