Gosula commited on
Commit
2db032e
1 Parent(s): e436ba4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -52,7 +52,13 @@ torch.manual_seed(0)
52
 
53
 
54
  # Create an instance of your model
55
- model = Cnn()
 
 
 
 
 
 
56
 
57
  # Specify the path to the saved model weights
58
  model_weights_path = 'model_weights.pth'
@@ -62,6 +68,14 @@ model.load_state_dict(torch.load(model_weights_path,map_location=torch.device('c
62
 
63
  # Set the model to evaluation mode for inference
64
  model.eval()
 
 
 
 
 
 
 
 
65
 
66
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
67
  stroke_color = st.sidebar.color_picker("Stroke color hex: ")
@@ -97,6 +111,6 @@ if canvas_result.image_data is not None:
97
  st.image(image1)
98
 
99
  image1.resize(1,1,28,28)
100
- st.title(np.argmax(model.predict(image1)))
101
  if canvas_result.json_data is not None:
102
  st.dataframe(pd.json_normalize(canvas_result.json_data["objects"]))
 
52
 
53
 
54
  # Create an instance of your model
55
+ model = NeuralNetClassifier(
56
+ Cnn,
57
+ max_epochs=10,
58
+ lr=0.002,
59
+ optimizer=torch.optim.Adam,
60
+ device=device,
61
+ )
62
 
63
  # Specify the path to the saved model weights
64
  model_weights_path = 'model_weights.pth'
 
68
 
69
  # Set the model to evaluation mode for inference
70
  model.eval()
71
+ # Create a NeuralNetClassifier using the loaded model
72
+ cnn = NeuralNetClassifier(
73
+ module=model,
74
+ max_epochs=0, # Set max_epochs to 0 to avoid additional training
75
+ lr=0.002, # You can set this to the learning rate used during training
76
+ optimizer=torch.optim.Adam, # You can set the optimizer used during training
77
+ device='cpu' # You can specify the device ('cpu' for CPU, 'cuda' for GPU, etc.)
78
+ )
79
 
80
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
81
  stroke_color = st.sidebar.color_picker("Stroke color hex: ")
 
111
  st.image(image1)
112
 
113
  image1.resize(1,1,28,28)
114
+ st.title(np.argmax(cnn.predict(image1)))
115
  if canvas_result.json_data is not None:
116
  st.dataframe(pd.json_normalize(canvas_result.json_data["objects"]))