Files changed (1) hide show
  1. app.py +8 -16
app.py CHANGED
@@ -40,26 +40,18 @@ class Cnn(nn.Module):
40
  return x
41
  torch.manual_seed(0)
42
 
43
- cnn = NeuralNetClassifier(
44
- Cnn,
45
- max_epochs=10,
46
- lr=0.002,
47
- optimizer=torch.optim.Adam,
48
- device=device,
49
- )
50
- cnn.fit(XCnn_train, y_train)
51
- # Specify the path to save the model weights
52
- # After training, save the model weights
53
- import torch
54
 
 
 
55
 
56
-
57
- # After training, save the model weights
58
  model_weights_path = 'model_weights.pth'
59
 
60
- torch.save(cnn.module_.state_dict(), model_weights_path)
61
-
62
 
 
 
63
 
64
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
65
  stroke_color = st.sidebar.color_picker("Stroke color hex: ")
@@ -95,6 +87,6 @@ if canvas_result.image_data is not None:
95
  st.image(image1)
96
 
97
  image1.resize(1,1,28,28)
98
- st.title(np.argmax(cnn.predict(image1)))
99
  if canvas_result.json_data is not None:
100
  st.dataframe(pd.json_normalize(canvas_result.json_data["objects"]))
 
40
  return x
41
  torch.manual_seed(0)
42
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ # Create an instance of your model
45
+ model = Cnn()
46
 
47
+ # Specify the path to the saved model weights
 
48
  model_weights_path = 'model_weights.pth'
49
 
50
+ # Load the model weights
51
+ model.load_state_dict(torch.load(model_weights_path))
52
 
53
+ # Set the model to evaluation mode for inference
54
+ model.eval()
55
 
56
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
57
  stroke_color = st.sidebar.color_picker("Stroke color hex: ")
 
87
  st.image(image1)
88
 
89
  image1.resize(1,1,28,28)
90
+ st.title(np.argmax(model.predict(image1)))
91
  if canvas_result.json_data is not None:
92
  st.dataframe(pd.json_normalize(canvas_result.json_data["objects"]))