# Hugging Face Space - status: Sleeping
import streamlit as st | |
import torch | |
from transformers import YolosImageProcessor, YolosForObjectDetection | |
from PIL import Image | |
import requests | |
from io import BytesIO | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches | |
import random | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Sample picture downloaded when the user clicks "Try Example Image".
EXAMPLE_URL = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/city-streets.jpg'
# Detections whose confidence score falls below this value are discarded.
THRESHOLD = 0.20
# Load model and processor
@st.cache_resource
def load_model():
    """Load the YOLOS-tiny image processor and object-detection model.

    Streamlit re-executes this whole script on every widget interaction
    (upload, button click, ...). Without caching, the weights would be
    re-initialised on every rerun. ``st.cache_resource`` keeps one shared
    instance across reruns.

    Returns:
        tuple: ``(processor, model)``, i.e. a ``YolosImageProcessor`` and a
        ``YolosForObjectDetection`` both loaded from ``hustvl/yolos-tiny``.
    """
    model_id = 'hustvl/yolos-tiny'
    processor = YolosImageProcessor.from_pretrained(model_id)
    model = YolosForObjectDetection.from_pretrained(model_id)
    return processor, model


processor, model = load_model()
# Function to detect objects in the image
def detect(image, threshold=None):
    """Run YOLOS object detection on a PIL image.

    Args:
        image: A ``PIL.Image.Image``. Images whose mode is not ``"RGB"``
            (e.g. grayscale, RGBA or CMYK) are converted to RGB first, so the
            processor always receives 3-channel input. Conversion does not
            change the image size, so the returned boxes still line up with
            the caller's original image.
        threshold: Minimum confidence score for a detection to be kept.
            ``None`` (the default) means the module-level ``THRESHOLD``,
            looked up at call time.

    Returns:
        dict: The single entry returned by
        ``processor.post_process_object_detection`` for this image. It holds
        ``"scores"``, ``"labels"`` and ``"boxes"``, with boxes scaled to the
        image's pixel size as ``(xmin, ymin, xmax, ymax)``.
    """
    if threshold is None:
        threshold = THRESHOLD

    rgb_image = image if image.mode == "RGB" else image.convert("RGB")

    # Preprocess image
    inputs = processor(images=rgb_image, return_tensors="pt")

    # Predict bounding boxes (inference only, so no autograd bookkeeping)
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); the post-processor wants
    # (height, width) per image to rescale boxes to pixel coordinates.
    target_sizes = torch.tensor([rgb_image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]
    return results
# Function to render bounding boxes
def render_box(image, results, threshold=None):
    """Draw detection boxes and labels over *image* and display it in Streamlit.

    Args:
        image: The image the detections were computed on (anything
            ``Axes.imshow`` accepts, e.g. a PIL image).
        results: Mapping with parallel ``"scores"``, ``"labels"`` and
            ``"boxes"`` sequences, as returned by ``detect``. Each box is
            ``(xmin, ymin, xmax, ymax)`` in pixel coordinates.
        threshold: Detections scoring below this are skipped. ``None`` (the
            default) means the module-level ``THRESHOLD``, looked up at call
            time.
    """
    if threshold is None:
        threshold = THRESHOLD

    # Use an explicit figure rather than pyplot's global "current figure" so
    # it can be passed to st.pyplot and then closed. Otherwise every Streamlit
    # rerun would leave another open figure behind in this process.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(image)
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            # Plain Python numbers work for tensors and ordinary sequences alike.
            score = float(score)
            if score < threshold:
                continue
            color = tuple(random.random() for _ in range(3))  # Random color for each box
            xmin, ymin, xmax, ymax = (float(v) for v in box)
            rect = patches.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                linewidth=2,
                edgecolor=color,
                facecolor='none',
            )
            ax.add_patch(rect)
            ax.text(
                xmin,
                ymin,
                f"{model.config.id2label[int(label)]}: {score:.2f}",
                color=color,
                fontsize=12,
                bbox=dict(facecolor='white', alpha=0.5),
            )
        ax.axis('off')
        st.pyplot(fig)
    finally:
        # Release the figure even if drawing or st.pyplot raised.
        plt.close(fig)
# Streamlit app
st.title("Object Detection with Hugging Face Transformers")

# Accept PNG as well as JPEG. "jpeg" is listed explicitly in case the
# installed Streamlit version does not treat it as an alias of "jpg".
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

image = None
if uploaded_file is not None:
    try:
        # convert("RGB") forces a full decode, so corrupt or non-image files
        # fail here with OSError (PIL's UnidentifiedImageError is a subclass).
        # It also normalises RGBA/grayscale/CMYK input to 3-channel RGB.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        st.error("The uploaded file could not be read as an image.")
elif st.button("Try Example Image"):
    try:
        # Bounded wait so a stalled download cannot hang the app forever.
        response = requests.get(EXAMPLE_URL, timeout=30)
        # Turn HTTP 4xx/5xx into an exception instead of decoding an error page.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except (requests.RequestException, OSError) as err:
        st.error(f"Could not load the example image: {err}")

if image is not None:
    results = detect(image)
    render_box(image, results)