Gosula commited on
Commit
39cdc68
1 Parent(s): 5c5075c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.datasets import fetch_openml
2
+ from sklearn.model_selection import train_test_split
3
+ import numpy as np
4
+ import torch
5
+ from skorch import NeuralNetClassifier
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ import matplotlib.pyplot as plt
9
+ mnist = fetch_openml('mnist_784', as_frame=False, cache=False)
10
+ X = mnist.data.astype('float32')
11
+ y = mnist.target.astype('int64')
12
+ X /= 255.0
13
+
14
+
15
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
+ XCnn = X.reshape(-1, 1, 28, 28)
17
+ XCnn_train, XCnn_test, y_train, y_test = train_test_split(XCnn, y, test_size=0.25, random_state=42)
18
+
19
+ from PIL import Image
20
+ import torchvision.transforms as transforms
21
+ class Cnn(nn.Module):
22
+ def __init__(self, dropout=0.5):
23
+ super(Cnn, self).__init__()
24
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
25
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
26
+ self.conv2_drop = nn.Dropout2d(p=dropout)
27
+ self.fc1 = nn.Linear(1600, 100) # 1600 = number channels * width * height
28
+ self.fc2 = nn.Linear(100, 10)
29
+ self.fc1_drop = nn.Dropout(p=dropout)
30
+
31
+ def forward(self, x):
32
+ x = torch.relu(F.max_pool2d(self.conv1(x), 2))
33
+ x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
34
+
35
+ # flatten over channel, height and width = 1600
36
+ x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
37
+
38
+ x = torch.relu(self.fc1_drop(self.fc1(x)))
39
+ x = torch.softmax(self.fc2(x), dim=-1)
40
+ return x
41
+ torch.manual_seed(0)
42
+
43
+ cnn = NeuralNetClassifier(
44
+ Cnn,
45
+ max_epochs=10,
46
+ lr=0.002,
47
+ optimizer=torch.optim.Adam,
48
+ device=device,
49
+ )
50
+ cnn.fit(XCnn_train, y_train)
51
+ # Specify the path to save the model weights
52
+ # After training, save the model weights
53
+ import torch
54
+
55
+
56
+
57
+ # After training, save the model weights
58
+ model_weights_path = 'model_weights.pth'
59
+
60
+ torch.save(cnn.module_.state_dict(), model_weights_path)
61
+
62
+
63
+
64
+ stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
65
+ stroke_color = st.sidebar.color_picker("Stroke color hex: ")
66
+ bg_color = st.sidebar.color_picker("Background color hex: ", "#eee")
67
+ bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
68
+ drawing_mode = st.sidebar.selectbox(
69
+ "Drawing tool:", ("freedraw", "line", "rect", "circle", "transform", "polygon")
70
+ )
71
+ realtime_update = st.sidebar.checkbox("Update in realtime", True)
72
+
73
+ # Create a canvas component
74
+ canvas_result = st_canvas(
75
+ fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
76
+ stroke_width=stroke_width,
77
+ stroke_color=stroke_color,
78
+ background_color=bg_color,
79
+ background_image=Image.open(bg_image) if bg_image else None,
80
+ update_streamlit=realtime_update,
81
+ height=300,
82
+ drawing_mode=drawing_mode,
83
+ display_toolbar=st.sidebar.checkbox("Display toolbar", True),
84
+ key="full_app",
85
+ )
86
+
87
+ # Do something interesting with the image data and paths
88
+ if canvas_result.image_data is not None:
89
+ #st.image(canvas_result.image_data)
90
+ image = canvas_result.image_data
91
+ image1 = image.copy()
92
+ image1 = image1.astype('uint8')
93
+ image1 = cv2.cvtColor(image1,cv2.COLOR_BGR2GRAY)
94
+ image1 = cv2.resize(image1,(28,28))
95
+ st.image(image1)
96
+
97
+ image1.resize(1,28,28,1)
98
+ st.title(np.argmax(cnn.predict(image1)))
99
+ if canvas_result.json_data is not None:
100
+ st.dataframe(pd.json_normalize(canvas_result.json_data["objects"]))