Spaces:
Sleeping
Sleeping
cleaning up and themeing
Browse files- .streamlit/config.toml +6 -0
- app.py +49 -25
.streamlit/config.toml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
primaryColor = "F36295"
|
3 |
+
backgroundColor = "#FFF"
|
4 |
+
secondaryBackgroundColor = "#3183D1"
|
5 |
+
textColor = "#000"
|
6 |
+
font = "sans-serif"
|
app.py
CHANGED
@@ -20,49 +20,66 @@ from scipy.special import softmax
|
|
20 |
@st.cache_resource
|
21 |
def load_picture():
|
22 |
"""
|
23 |
-
|
24 |
-
to be displayed in streamlit.
|
25 |
"""
|
26 |
-
# load the mnist dataset
|
27 |
-
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
28 |
-
# plot the first 9 images
|
29 |
-
for i in range(9):
|
30 |
-
plt.subplot(330 + 1 + i)
|
31 |
-
image = x_train[i] / 255.0
|
32 |
-
plt.imshow(image, cmap=plt.get_cmap("gray"))
|
33 |
-
|
34 |
-
# Save the plot as a png file and show it in streamlit
|
35 |
-
# This is commented out for not because the plot was created and saved in the img directory during the initial run of the app locally
|
36 |
-
# plt.savefig("img/show.png")
|
37 |
st.image("img/show.png", width=250, caption="First 9 images from the MNIST dataset")
|
38 |
|
39 |
|
40 |
def keras_prediction(final, model_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
load_time = time.time()
|
42 |
model = models.load_model(
|
43 |
os.path.abspath(os.path.join(os.path.dirname(__file__), model_path))
|
44 |
)
|
45 |
after_load_curr = time.time()
|
|
|
|
|
46 |
curr_time = time.time()
|
47 |
prediction = model.predict(final[None, ...])
|
48 |
after_time = time.time()
|
|
|
49 |
return prediction, after_time - curr_time, after_load_curr - load_time
|
50 |
|
51 |
|
52 |
def onnx_prediction(final, model_path):
|
53 |
-
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
im_np = im_np.astype("float32")
|
|
|
|
|
56 |
load_curr = time.time()
|
57 |
session = onnxruntime.InferenceSession(model_path, None)
|
58 |
input_name = session.get_inputs()[0].name
|
59 |
output_name = session.get_outputs()[0].name
|
60 |
after_load_curr = time.time()
|
61 |
|
|
|
62 |
curr_time = time.time()
|
63 |
result = session.run([output_name], {input_name: im_np})
|
64 |
prediction = softmax(np.array(result).squeeze(), axis=0)
|
65 |
after_time = time.time()
|
|
|
66 |
return prediction, after_time - curr_time, after_load_curr - load_curr
|
67 |
|
68 |
|
@@ -70,21 +87,24 @@ def main():
|
|
70 |
"""
|
71 |
The main function/primary entry point of the app
|
72 |
"""
|
73 |
-
#
|
|
|
74 |
st.title("MNIST Digit Recognizer")
|
75 |
|
76 |
col1, col2 = st.columns([0.8, 0.2], gap="small")
|
77 |
with col1:
|
78 |
st.markdown(
|
79 |
"""
|
80 |
-
This Streamlit app
|
81 |
- Change the stroke width of the digit using the slider
|
82 |
- Choose what model you use for predictions
|
83 |
- Onnx: The mnist-12 Onnx model from <a href="https://xethub.com/XetHub/onnx-models/src/branch/main/vision/classification/mnist">Onnx's pre-trained MNIST models</a>
|
84 |
- Autokeras: A model generated using the <a href="https://autokeras.com/image_classifier/">Autokeras image classifier class</a>
|
85 |
-
- Basic: A simple
|
|
|
|
|
86 |
|
87 |
-
|
88 |
unsafe_allow_html=True,
|
89 |
)
|
90 |
with col2:
|
@@ -118,8 +138,8 @@ def main():
|
|
118 |
background_color="#000",
|
119 |
background_image=None,
|
120 |
update_streamlit=True,
|
121 |
-
height=
|
122 |
-
width=
|
123 |
drawing_mode="freedraw",
|
124 |
point_display_radius=0,
|
125 |
key="canvas",
|
@@ -144,10 +164,14 @@ def main():
|
|
144 |
prediction, pred_time, load_time = onnx_prediction(final, model_path)
|
145 |
|
146 |
# print the prediction
|
147 |
-
st.header(f"
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
151 |
|
152 |
# Create a 2 column dataframe with one column as the digits and the other as the probability
|
153 |
data = pd.DataFrame(
|
|
|
20 |
@st.cache_resource
|
21 |
def load_picture():
|
22 |
"""
|
23 |
+
Shows the MNIST dataset image
|
|
|
24 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
st.image("img/show.png", width=250, caption="First 9 images from the MNIST dataset")
|
26 |
|
27 |
|
28 |
def keras_prediction(final, model_path):
|
29 |
+
"""Make a predition using a Keras model
|
30 |
+
|
31 |
+
Args:
|
32 |
+
final: The input image
|
33 |
+
model_path: The path of the Keras model to load
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
np.array: Predictions from the model. The probability of each digit.
|
37 |
+
float: Time to make the prediction
|
38 |
+
float: Time to load the model
|
39 |
+
"""
|
40 |
+
# load the model
|
41 |
load_time = time.time()
|
42 |
model = models.load_model(
|
43 |
os.path.abspath(os.path.join(os.path.dirname(__file__), model_path))
|
44 |
)
|
45 |
after_load_curr = time.time()
|
46 |
+
|
47 |
+
# Make the prediction
|
48 |
curr_time = time.time()
|
49 |
prediction = model.predict(final[None, ...])
|
50 |
after_time = time.time()
|
51 |
+
|
52 |
return prediction, after_time - curr_time, after_load_curr - load_time
|
53 |
|
54 |
|
55 |
def onnx_prediction(final, model_path):
|
56 |
+
"""Make a predition using an Onnx model
|
57 |
+
Args:
|
58 |
+
final: The input image
|
59 |
+
model_path: The path of the Onnx model to load
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
np.array: Predictions from the model. The probability of each digit.
|
63 |
+
float: Time to make the prediction
|
64 |
+
float: Time to load the model
|
65 |
+
"""
|
66 |
+
im_np = np.expand_dims(final, axis=0)
|
67 |
+
im_np = np.expand_dims(im_np, axis=0)
|
68 |
im_np = im_np.astype("float32")
|
69 |
+
|
70 |
+
# Load the model
|
71 |
load_curr = time.time()
|
72 |
session = onnxruntime.InferenceSession(model_path, None)
|
73 |
input_name = session.get_inputs()[0].name
|
74 |
output_name = session.get_outputs()[0].name
|
75 |
after_load_curr = time.time()
|
76 |
|
77 |
+
# Make the prediction
|
78 |
curr_time = time.time()
|
79 |
result = session.run([output_name], {input_name: im_np})
|
80 |
prediction = softmax(np.array(result).squeeze(), axis=0)
|
81 |
after_time = time.time()
|
82 |
+
|
83 |
return prediction, after_time - curr_time, after_load_curr - load_curr
|
84 |
|
85 |
|
|
|
87 |
"""
|
88 |
The main function/primary entry point of the app
|
89 |
"""
|
90 |
+
# Setup
|
91 |
+
st.set_page_config(layout="wide")
|
92 |
st.title("MNIST Digit Recognizer")
|
93 |
|
94 |
col1, col2 = st.columns([0.8, 0.2], gap="small")
|
95 |
with col1:
|
96 |
st.markdown(
|
97 |
"""
|
98 |
+
This Streamlit app demonstrates the performance of multiple different neural networks (and associated frameworks) trained on the <a href="https://yann.lecun.com/exdb/mnist/">MNIST dataset</a> to predict handwritten digits. Draw a digit in the canvas below and see the model's prediction. You can:
|
99 |
- Change the stroke width of the digit using the slider
|
100 |
- Choose what model you use for predictions
|
101 |
- Onnx: The mnist-12 Onnx model from <a href="https://xethub.com/XetHub/onnx-models/src/branch/main/vision/classification/mnist">Onnx's pre-trained MNIST models</a>
|
102 |
- Autokeras: A model generated using the <a href="https://autokeras.com/image_classifier/">Autokeras image classifier class</a>
|
103 |
+
- Basic: A simple <a href="https://keras.io/">Keras</a> model with two layers where each layer has 300 nodes. The model was trained on the MNIST dataset for 35 epochs.
|
104 |
+
|
105 |
+
Like any machine learning model, this model is a function of the data it was fed during training. As you can see in the picture, the numbers in the images have a specific shape, location, and size. By playing around with the stroke width and where you draw the digit, you can see how the model's prediction changes.
|
106 |
|
107 |
+
If you change your selected model after drawing the digit, that same drawing will be used with the newly selected model. To clear your "hand" drawn digit, click the trashcan icon under the drawing canvas.""",
|
108 |
unsafe_allow_html=True,
|
109 |
)
|
110 |
with col2:
|
|
|
138 |
background_color="#000",
|
139 |
background_image=None,
|
140 |
update_streamlit=True,
|
141 |
+
height=300,
|
142 |
+
width=300,
|
143 |
drawing_mode="freedraw",
|
144 |
point_display_radius=0,
|
145 |
key="canvas",
|
|
|
164 |
prediction, pred_time, load_time = onnx_prediction(final, model_path)
|
165 |
|
166 |
# print the prediction
|
167 |
+
st.header(f"Results")
|
168 |
+
table_data = {
|
169 |
+
"Model": [model_choice],
|
170 |
+
"Prediction": [np.argmax(prediction)],
|
171 |
+
"Load time (ms)": f"{load_time * 1000:.2f}",
|
172 |
+
"Prediction time (ms)": f"{pred_time * 1000:.2f}",
|
173 |
+
}
|
174 |
+
st.table(table_data)
|
175 |
|
176 |
# Create a 2 column dataframe with one column as the digits and the other as the probability
|
177 |
data = pd.DataFrame(
|