# yolov5-ui / app.py
# robmarkcole's picture
# Upload app.py
# 5885827
import streamlit as st
import torch
from PIL import Image, ImageDraw
from typing import Tuple
import numpy as np
import const
import time
def draw_box(
    draw: "ImageDraw.ImageDraw",
    box: Tuple[float, float, float, float],
    text: str = "",
    color: Tuple[int, int, int] = (255, 255, 0),
) -> None:
    """
    Draw a bounding box, with an optional label, on an image.

    Args:
        draw: Drawing context of the target image; it must provide ``line``
            and ``text`` methods (e.g. the object returned by
            ``ImageDraw.Draw(image)``).
        box: Box corners ordered ``(y_min, x_min, y_max, x_max)`` -- note
            that y comes before x.
        text: Label written next to the top-left corner; nothing is written
            when it is empty.
        color: Colour used for both the outline and the label.
    """
    line_width = 3
    font_height = 8
    top, left, bottom, right = box

    # Trace the rectangle as one closed polyline that ends where it started.
    outline = [
        (left, top),
        (left, bottom),
        (right, bottom),
        (right, top),
        (left, top),
    ]
    draw.line(outline, width=line_width, fill=color)

    if text:
        # The label normally sits just above the top edge. abs() keeps its y
        # coordinate non-negative when the box is closer to the top of the
        # image than line_width + font_height.
        label_y = abs(top - line_width - font_height)
        draw.text((left + line_width, label_y), text, fill=color)
@st.cache(allow_output_mutation=True, show_spinner=True)
def get_model(model_id: str = "yolov5s"):
    """
    Load a model from the ``ultralytics/yolov5`` torch.hub repository.

    Args:
        model_id: Name of the hub entry point to load (default ``"yolov5s"``).

    Returns:
        Whatever ``torch.hub.load`` returns for that entry point.

    The function is wrapped in ``st.cache`` with ``allow_output_mutation=True``
    and ``show_spinner=True``; presumably this lets Streamlit reuse the loaded
    model across script reruns instead of loading it again.
    """
    return torch.hub.load("ultralytics/yolov5", model_id)
# Settings
st.sidebar.title("Settings")
model_id = st.sidebar.selectbox("Pretrained model", const.PRETRAINED_MODELS, index=1)
img_size = st.sidebar.selectbox("Image resize for inference", const.IMAGE_SIZES, index=1)
CONFIDENCE = st.sidebar.slider(
    "Confidence threshold",
    const.MIN_CONF,
    const.MAX_CONF,
    const.DEFAULT_CONF,
)

model = get_model(model_id)
st.title(f"{model_id}")

# Input image: the uploaded file when there is one, otherwise the default image.
img_file_buffer = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
image_source = img_file_buffer if img_file_buffer is not None else const.DEFAULT_IMAGE
# Normalise to RGB so the box colour can be drawn on grayscale or palette
# uploads as well (assumes const.RED is an RGB tuple -- confirm in const.py).
pil_image = Image.open(image_source).convert("RGB")
st.text(f"Input image width and height: {pil_image.width} x {pil_image.height}")

# Time only the model call, so the figure reported at the bottom of the page
# is the inference time and not the time spent drawing and rendering.
start_time = time.perf_counter()
results = model(pil_image, size=img_size)
end_time = time.perf_counter()

# Detections for the single input image, keeping only rows whose confidence
# is above the threshold chosen in the sidebar.
df = results.pandas().xyxy[0]
df = df[df["confidence"] > CONFIDENCE]

# Draw one labelled box per remaining detection, directly onto the image.
draw = ImageDraw.Draw(pil_image)
for _, obj in df.iterrows():
    name = obj["name"]
    box_label = f"{name}"
    draw_box(
        draw,
        # draw_box expects (y_min, x_min, y_max, x_max).
        (obj["ymin"], obj["xmin"], obj["ymax"], obj["xmax"]),
        text=box_label,
        color=const.RED,
    )

st.image(
    np.array(pil_image),
    caption="Processed image",
    use_column_width=True,
)
st.text(f"Time to inference: {round(end_time - start_time, 2)} sec")
st.table(df)