|
import streamlit as st |
|
from PIL import Image |
|
import numpy as np |
|
from ultralytics import YOLO |
|
from io import BytesIO |
|
|
|
|
|
@st.cache_resource
def load_model(weights_path: str = "weights.pt"):
    """
    Load the YOLO model and cache it.

    The function is wrapped in ``st.cache_resource``, so Streamlit keeps the
    loaded model around and hands back the same object on later calls made
    with the same argument instead of reading the weights again.

    Args:
        weights_path: Path of the weights file passed to ``YOLO``. Defaults
            to ``"weights.pt"``, so calling ``load_model()`` with no
            arguments behaves as it always did.

    Returns:
        The ``YOLO`` instance built from ``weights_path``.
    """
    return YOLO(weights_path)
|
|
|
def predict(model, image, font_size, line_width):
    """
    Run the model on ``image`` and return the annotated picture.

    Args:
        model: Object exposing ``predict(image)``; only the first element of
            what it returns is used.
        image: Input picture handed to ``model.predict`` unchanged.
        font_size: Forwarded to the result's ``plot`` call.
        line_width: Forwarded to the result's ``plot`` call.

    Returns:
        A PIL image built from the plotted output with its last axis reversed.
    """
    first_result = model.predict(image)[0]
    plotted = first_result.plot(
        conf=False,
        pil=True,
        font_size=font_size,
        line_width=line_width,
    )
    # Reverse the channel axis before wrapping the data in a PIL image.
    # NOTE(review): this treats the value returned by plot(pil=True) as a
    # BGR array - confirm that holds for the installed ultralytics version.
    return Image.fromarray(plotted[..., ::-1])
|
|
|
def file_uploader_cb(uploaded_file, font_size, line_width):
    """
    Show an uploaded picture side by side with its annotated prediction.

    The left column displays the upload; the right column displays the
    annotated result together with a button that downloads it as a JPEG.
    Inference uses the module-level ``model`` object.

    Args:
        uploaded_file: File-like object accepted by ``Image.open``.
        font_size: Label font size forwarded to ``predict``.
        line_width: Box line width forwarded to ``predict``.

    Returns:
        None. If the upload cannot be decoded as an image, an error message
        is shown and nothing else is rendered.
    """
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # The uploader only filters on file extension, so a file that is not
        # really an image (or is truncated) can reach this point. Pillow
        # reports that with OSError / UnidentifiedImageError (a subclass);
        # report it and return so the rest of the page still renders.
        st.error("The uploaded file could not be read as an image.")
        return

    original_col, prediction_col = st.columns(2)
    with original_col:
        st.image(image, caption='Uploaded Image', use_column_width=True)

    annotated_img = predict(model, image, font_size, line_width)
    with prediction_col:
        st.image(annotated_img, caption='Prediction', use_column_width=True)

        # Encode the annotated picture in memory for the download button.
        imbuffer = BytesIO()
        annotated_img.save(imbuffer, format="JPEG")
        st.download_button(
            "Download Annotated Image",
            data=imbuffer,
            file_name="Annotated_Sketch.jpeg",
            mime="image/jpeg",
            key="upload",
        )
|
|
|
def image_capture_cb(capture, font_size, line_width, col):
    """
    Annotate a captured picture and render the result inside ``col``.

    The capture is opened as an RGB image and run through ``predict`` using
    the module-level ``model``. The annotated picture and a JPEG download
    button are then drawn in the given column.

    Args:
        capture: File-like object accepted by ``Image.open``.
        font_size: Label font size forwarded to ``predict``.
        line_width: Box line width forwarded to ``predict``.
        col: Streamlit container used as the drawing context for the output.
    """
    frame = Image.open(capture).convert("RGB")
    annotated = predict(model, frame, font_size, line_width)

    with col:
        st.image(annotated, caption='Prediction', use_column_width=True)

        # In-memory JPEG copy that backs the download button.
        jpeg_buffer = BytesIO()
        annotated.save(jpeg_buffer, format="JPEG")
        st.download_button(
            "Download Annotated Image",
            data=jpeg_buffer,
            file_name="Annotated_Sketch.jpeg",
            mime="image/jpeg",
            key="capture",
        )
|
|
|
if __name__ == "__main__":
    # Script entry point: configure the page, read the drawing options from
    # the sidebar, load the model, then build the three tabs.
    st.set_page_config(
        page_title="Circuit Sketch Recognizer",
        layout="wide",
    )
    st.title("Circuit Sketch Recognition")

    # Drawing options handed to the callbacks below.
    with st.sidebar:
        font_size = st.slider(label="Font Size", min_value=6, max_value=64, step=1, value=24)
        line_width = st.slider(label="Bounding Box Line Thickness", min_value=1, max_value=8, step=1, value=3)

    # Kept at module level on purpose: file_uploader_cb and image_capture_cb
    # read `model` as a global rather than receiving it as an argument.
    model = load_model()

    # Bind each tab to a name so the content placed in it visibly matches its
    # label. Indexing the tabs by position had put the file uploader under
    # "Capture Picture" and the camera under "Upload Your Image".
    capture_tab, upload_tab, examples_tab = st.tabs(
        ["Capture Picture", "Upload Your Image", "Show Examples"]
    )

    with capture_tab:
        camera_col, result_col = st.columns(2)
        with camera_col:
            capture = st.camera_input("Take a picture with Camera")
        if capture is not None:
            # The prediction is drawn in the column next to the camera widget.
            image_capture_cb(capture, font_size, line_width, result_col)

    with upload_tab:
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            file_uploader_cb(uploaded_file, font_size, line_width)

    with examples_tab:
        first_col, second_col = st.columns(2)
        with first_col:
            st.image('example1.jpg', use_column_width=True, caption='Example 1')
        with second_col:
            st.image('example2.jpg', use_column_width=True, caption='Example 2')
|
|