Update app.py
Browse files
app.py
CHANGED
@@ -63,19 +63,63 @@ import torch
|
|
63 |
from PIL import Image
|
64 |
import cv2
|
65 |
import numpy as np
|
66 |
-
from your_model_module import Cnn # Import your model architecture
|
67 |
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
# Specify the path to the saved model weights
|
72 |
model_weights_path = 'model_weights.pth'
|
73 |
|
74 |
-
# Load the model weights
|
75 |
-
model.load_state_dict(torch.load(model_weights_path,
|
76 |
|
77 |
# Set the model to evaluation mode for inference
|
78 |
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
|
81 |
stroke_color = st.sidebar.color_picker("Stroke color hex: ")
|
@@ -102,23 +146,16 @@ canvas_result = st_canvas(
|
|
102 |
|
103 |
# Do something interesting with the image data and paths
|
104 |
if canvas_result.image_data is not None:
|
|
|
105 |
image = canvas_result.image_data
|
106 |
image1 = image.copy()
|
107 |
image1 = image1.astype('uint8')
|
108 |
-
image1 = cv2.cvtColor(image1,
|
109 |
-
image1 = cv2.resize(image1,
|
110 |
st.image(image1)
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
tensor_image = torch.tensor(image1, dtype=torch.float32)
|
118 |
-
prediction = model(tensor_image)
|
119 |
-
|
120 |
-
# Display the predicted class
|
121 |
-
predicted_class = prediction.argmax().item()
|
122 |
-
st.title(f"Predicted Class: {predicted_class}")
|
123 |
-
|
124 |
-
|
|
|
63 |
from PIL import Image
|
64 |
import cv2
|
65 |
import numpy as np
|
|
|
66 |
|
67 |
+
|
68 |
+
#device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
69 |
+
XCnn = X.reshape(-1, 1, 28, 28)
|
70 |
+
XCnn_train, XCnn_test, y_train, y_test = train_test_split(XCnn, y, test_size=0.25, random_state=42)
|
71 |
+
|
72 |
+
from PIL import Image
|
73 |
+
import torchvision.transforms as transforms
|
74 |
+
class Cnn(nn.Module):
|
75 |
+
def __init__(self, dropout=0.5):
|
76 |
+
super(Cnn, self).__init__()
|
77 |
+
self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
|
78 |
+
self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
|
79 |
+
self.conv2_drop = nn.Dropout2d(p=dropout)
|
80 |
+
self.fc1 = nn.Linear(1600, 100) # 1600 = number channels * width * height
|
81 |
+
self.fc2 = nn.Linear(100, 10)
|
82 |
+
self.fc1_drop = nn.Dropout(p=dropout)
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
x = torch.relu(F.max_pool2d(self.conv1(x), 2))
|
86 |
+
x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
|
87 |
+
|
88 |
+
# flatten over channel, height and width = 1600
|
89 |
+
x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
|
90 |
+
|
91 |
+
x = torch.relu(self.fc1_drop(self.fc1(x)))
|
92 |
+
x = torch.softmax(self.fc2(x), dim=-1)
|
93 |
+
return x
|
94 |
+
torch.manual_seed(0)
|
95 |
+
|
96 |
+
|
97 |
+
# # Create an instance of your model
|
98 |
+
# model = NeuralNetClassifier(
|
99 |
+
# Cnn,
|
100 |
+
# max_epochs=10,
|
101 |
+
# lr=0.002,
|
102 |
+
# optimizer=torch.optim.Adam,
|
103 |
+
# device=device,
|
104 |
+
# )
|
105 |
+
model=Cnn()
|
106 |
|
107 |
# Specify the path to the saved model weights
|
108 |
model_weights_path = 'model_weights.pth'
|
109 |
|
110 |
+
# Load the model weights
|
111 |
+
model.load_state_dict(torch.load(model_weights_path,map_location=torch.device('cpu')))
|
112 |
|
113 |
# Set the model to evaluation mode for inference
|
114 |
model.eval()
|
115 |
+
# Create a NeuralNetClassifier using the loaded model
|
116 |
+
cnn = NeuralNetClassifier(
|
117 |
+
module=model,
|
118 |
+
max_epochs=0, # Set max_epochs to 0 to avoid additional training
|
119 |
+
lr=0.002, # You can set this to the learning rate used during training
|
120 |
+
optimizer=torch.optim.Adam, # You can set the optimizer used during training
|
121 |
+
device='cpu' # You can specify the device ('cpu' for CPU, 'cuda' for GPU, etc.)
|
122 |
+
)
|
123 |
|
124 |
stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
|
125 |
stroke_color = st.sidebar.color_picker("Stroke color hex: ")
|
|
|
146 |
|
147 |
# Do something interesting with the image data and paths
|
148 |
if canvas_result.image_data is not None:
|
149 |
+
#st.image(canvas_result.image_data)
|
150 |
image = canvas_result.image_data
|
151 |
image1 = image.copy()
|
152 |
image1 = image1.astype('uint8')
|
153 |
+
image1 = cv2.cvtColor(image1,cv2.COLOR_BGR2GRAY)
|
154 |
+
image1 = cv2.resize(image1,(28,28))
|
155 |
st.image(image1)
|
156 |
|
157 |
+
image1.resize(1,1,28,28)
|
158 |
+
st.title(np.argmax(cnn.predict(image1)))
|
159 |
+
if canvas_result.json_data is not None:
|
160 |
+
st.dataframe(pd.json_normalize(canvas_result.json_data["objects"])) check this code for above error
|
161 |
+
ChatGPT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|