Update app.py
Browse files
app.py
CHANGED
@@ -58,25 +58,25 @@ torch.manual_seed(0)
|
|
58 |
# lr=0.002,
|
59 |
# optimizer=torch.optim.Adam,
|
60 |
# device=device,
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
# Specify the path to the saved model weights
|
65 |
model_weights_path = 'model_weights.pth'
|
66 |
|
67 |
-
# Load the model weights
|
68 |
-
model.load_state_dict(torch.load(model_weights_path,map_location=torch.device('cpu')))
|
69 |
|
70 |
# Set the model to evaluation mode for inference
|
71 |
model.eval()
|
72 |
-
# Create a NeuralNetClassifier using the loaded model
|
73 |
-
cnn = NeuralNetClassifier(
|
74 |
-
module=model,
|
75 |
-
max_epochs=0, # Set max_epochs to 0 to avoid additional training
|
76 |
-
lr=0.002, # You can set this to the learning rate used during training
|
77 |
-
optimizer=torch.optim.Adam, # You can set the optimizer used during training
|
78 |
-
device='cpu' # You can specify the device ('cpu' for CPU, 'cuda' for GPU, etc.)
|
79 |
-
)
|
80 |
|
81 |
stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
|
82 |
stroke_color = st.sidebar.color_picker("Stroke color hex: ")
|
@@ -103,16 +103,23 @@ canvas_result = st_canvas(
|
|
103 |
|
104 |
# Do something interesting with the image data and paths
|
105 |
if canvas_result.image_data is not None:
|
106 |
-
#st.image(canvas_result.image_data)
|
107 |
image = canvas_result.image_data
|
108 |
image1 = image.copy()
|
109 |
image1 = image1.astype('uint8')
|
110 |
-
image1 = cv2.cvtColor(image1,cv2.COLOR_BGR2GRAY)
|
111 |
-
image1 = cv2.resize(image1,(28,28))
|
112 |
st.image(image1)
|
113 |
|
114 |
-
image1
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
|
|
58 |
# lr=0.002,
|
59 |
# optimizer=torch.optim.Adam,
|
60 |
# device=device,
|
61 |
+
import streamlit as st
|
62 |
+
from st_canvas import st_canvas
|
63 |
+
import torch
|
64 |
+
from PIL import Image
|
65 |
+
import cv2
|
66 |
+
import numpy as np
|
67 |
+
from your_model_module import Cnn # Import your model architecture
|
68 |
+
|
69 |
+
# Create an instance of your model (Cnn model)
|
70 |
+
model = Cnn()
|
71 |
|
72 |
# Specify the path to the saved model weights
|
73 |
model_weights_path = 'model_weights.pth'
|
74 |
|
75 |
+
# Load the model weights onto a CPU device (if you want to use the CPU)
|
76 |
+
model.load_state_dict(torch.load(model_weights_path, map_location=torch.device('cpu')))
|
77 |
|
78 |
# Set the model to evaluation mode for inference
|
79 |
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
|
82 |
stroke_color = st.sidebar.color_picker("Stroke color hex: ")
|
|
|
103 |
|
104 |
# Do something interesting with the image data and paths
|
105 |
if canvas_result.image_data is not None:
|
|
|
106 |
image = canvas_result.image_data
|
107 |
image1 = image.copy()
|
108 |
image1 = image1.astype('uint8')
|
109 |
+
image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY)
|
110 |
+
image1 = cv2.resize(image1, (28, 28))
|
111 |
st.image(image1)
|
112 |
|
113 |
+
# Convert the image for prediction (assuming image1 is in the right format)
|
114 |
+
image1 = image1[np.newaxis, np.newaxis, ...] # Add batch and channel dimensions
|
115 |
+
|
116 |
+
# Perform prediction using the pre-trained model
|
117 |
+
with torch.no_grad():
|
118 |
+
tensor_image = torch.tensor(image1, dtype=torch.float32)
|
119 |
+
prediction = model(tensor_image)
|
120 |
+
|
121 |
+
# Display the predicted class
|
122 |
+
predicted_class = prediction.argmax().item()
|
123 |
+
st.title(f"Predicted Class: {predicted_class}")
|
124 |
+
|
125 |
|