"""Streamlit app: train a small CNN on MNIST and classify digits drawn on a canvas."""
import matplotlib.pyplot as plt  # NOTE(review): currently unused
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms  # NOTE(review): currently unused
from PIL import Image
from skorch import NeuralNetClassifier
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from streamlit_drawable_canvas import st_canvas
from torch import nn

MODEL_WEIGHTS_PATH = "model_weights.pth"
# Square canvas so that resizing to 28x28 does not distort the digit's aspect ratio.
CANVAS_SIZE = 300

device = "cuda" if torch.cuda.is_available() else "cpu"


class Cnn(nn.Module):
    """Two conv layers + two dense layers for (N, 1, 28, 28) inputs; returns class probabilities."""

    def __init__(self, dropout=0.5):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d(p=dropout)
        self.fc1 = nn.Linear(1600, 100)  # 1600 = 64 channels * 5 * 5 after two conv+pool stages
        self.fc2 = nn.Linear(100, 10)
        self.fc1_drop = nn.Dropout(p=dropout)

    def forward(self, x):
        x = torch.relu(F.max_pool2d(self.conv1(x), 2))
        x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten channel, height and width (28 -> 26 -> 13 -> 11 -> 5, so 64 * 5 * 5 = 1600).
        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
        x = torch.relu(self.fc1_drop(self.fc1(x)))
        # Probabilities, not logits: skorch's NeuralNetClassifier (default NLLLoss criterion)
        # takes the log of the module output itself.
        return torch.softmax(self.fc2(x), dim=-1)


@st.cache_resource
def train_model():
    """Download MNIST, train the CNN, save its weights and return the fitted classifier.

    Streamlit re-executes this whole script on every widget interaction; caching the
    resource keeps the download and 10-epoch training from repeating on every stroke.
    """
    mnist = fetch_openml("mnist_784", as_frame=False, cache=False)
    X = mnist.data.astype("float32") / 255.0  # scale pixels to [0, 1]
    y = mnist.target.astype("int64")
    X_cnn = X.reshape(-1, 1, 28, 28)  # NCHW layout expected by Cnn
    X_train, X_test, y_train, y_test = train_test_split(
        X_cnn, y, test_size=0.25, random_state=42
    )

    torch.manual_seed(0)
    model = NeuralNetClassifier(
        Cnn,
        max_epochs=10,
        lr=0.002,
        optimizer=torch.optim.Adam,
        device=device,
    )
    model.fit(X_train, y_train)

    torch.save(model.module_.state_dict(), MODEL_WEIGHTS_PATH)
    return model


def canvas_to_mnist(rgba):
    """Convert a canvas image into a float32 (1, 1, 28, 28) array like the training data.

    Returns None when the canvas contains no drawing (uniform image).
    """
    # NOTE(review): assumes the RGB channels of image_data include the canvas background
    # colour; if the background comes back transparent, derive the digit from the alpha
    # channel instead — confirm against streamlit-drawable-canvas.
    rgb = np.asarray(rgba)[..., :3].astype("uint8")
    gray = Image.fromarray(rgb).convert("L").resize((28, 28), Image.BILINEAR)
    digit = np.asarray(gray, dtype="float32") / 255.0

    # MNIST digits are light-on-dark. The background covers most pixels, so a bright
    # median means dark ink on a light background (the default here) and must be inverted.
    if np.median(digit) > 0.5:
        digit = 1.0 - digit

    # Stretch so the background maps to 0 and the strongest ink to 1, as in MNIST.
    digit = digit - digit.min()
    peak = digit.max()
    if peak == 0:
        return None
    digit = digit / peak
    return digit.reshape(1, 1, 28, 28)


cnn = train_model()

stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
stroke_color = st.sidebar.color_picker("Stroke color hex: ")
bg_color = st.sidebar.color_picker("Background color hex: ", "#eee")
bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
drawing_mode = st.sidebar.selectbox(
    "Drawing tool:", ("freedraw", "line", "rect", "circle", "transform", "polygon")
)
realtime_update = st.sidebar.checkbox("Update in realtime", True)

# Create a canvas component
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
    stroke_width=stroke_width,
    stroke_color=stroke_color,
    background_color=bg_color,
    background_image=Image.open(bg_image) if bg_image else None,
    update_streamlit=realtime_update,
    height=CANVAS_SIZE,
    width=CANVAS_SIZE,
    drawing_mode=drawing_mode,
    display_toolbar=st.sidebar.checkbox("Display toolbar", True),
    key="full_app",
)

# Classify the drawing
if canvas_result.image_data is not None:
    digit = canvas_to_mnist(canvas_result.image_data)
    if digit is not None:
        st.image(digit[0, 0], clamp=True)
        # predict() already returns class labels (argmax of predict_proba), so take the
        # single label directly instead of arg-maxing it again (which always gave 0).
        st.title(str(cnn.predict(digit)[0]))

if canvas_result.json_data is not None:
    st.dataframe(pd.json_normalize(canvas_result.json_data["objects"]))