import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Human-readable CIFAR-10 class names, indexed by integer label.
CIFAR10_CLASSES = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]


# Define the CNN
class SimpleCNN(nn.Module):
    """Two conv/pool stages followed by two fully connected layers.

    Expects 3-channel 32x32 images: each 3x3 conv with padding=1 keeps the
    spatial size and each 2x2 max-pool halves it, so fc1 sees 32 x 8 x 8
    features. Produces 10 logits (one per CIFAR-10 class).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # 32x32 -> 16x16
        x = self.pool(torch.relu(self.conv2(x)))  # 16x16 -> 8x8
        # Flatten per sample. Unlike view(-1, 32 * 8 * 8), this never silently
        # folds a wrongly sized input into a different batch size; a bad shape
        # fails loudly in fc1 instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


def _make_loaders(batch_size):
    """Return (trainloader, testloader) for CIFAR-10, downloading into ./data if needed."""
    # Normalize each channel from [0, 1] to [-1, 1]; undone in _image_grid_figure.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False)
    return trainloader, testloader


def _evaluate(net, loader):
    """Return (correct, total) top-1 prediction counts of *net* over *loader*."""
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            _, predicted = torch.max(net(images), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


def _plot_losses(loss_values):
    """Return a new Figure plotting the mean training loss of each epoch."""
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(1, len(loss_values) + 1), loss_values, marker='o')
    ax.set_title('Training Loss over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    return fig


def _image_grid_figure(images):
    """Return a new Figure showing *images* (normalized to [-1, 1]) as one grid."""
    grid = torchvision.utils.make_grid(images)
    grid = grid / 2 + 0.5  # Unnormalize back to [0, 1]
    fig, ax = plt.subplots()
    ax.imshow(np.transpose(grid.numpy(), (1, 2, 0)))  # CHW -> HWC for matplotlib
    ax.axis('off')
    return fig


def _show_figure(fig):
    """Render *fig* in the Streamlit page, then close it so figures don't pile up across reruns."""
    st.pyplot(fig)
    plt.close(fig)


# Function to train the model
def train_model(num_epochs, *, batch_size=64, lr=0.01, momentum=0.9, num_preview=8):
    """Train SimpleCNN on CIFAR-10 and report progress in the Streamlit page.

    Downloads CIFAR-10 into ./data if missing. Trains with SGD and
    cross-entropy loss, then writes to the page:
    - the mean loss of each epoch,
    - a loss-curve plot,
    - test-set accuracy,
    - a preview of the first ``num_preview`` test images with their
      predicted and actual labels.

    Args:
        num_epochs: Number of passes over the training set.
        batch_size: Mini-batch size for both the train and test loaders.
        lr: SGD learning rate.
        momentum: SGD momentum.
        num_preview: How many test images to display with predictions; capped
            at the test batch size.
    """
    trainloader, testloader = _make_loaders(batch_size)

    net = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

    loss_values = []
    st.write("Training the model...")
    net.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / len(trainloader)
        loss_values.append(epoch_loss)
        st.write(f'Epoch {epoch + 1}, Loss: {epoch_loss:.3f}')
    st.write('Finished Training')

    # Plot the loss values on a dedicated figure so later drawing cannot land on it.
    _show_figure(_plot_losses(loss_values))

    correct, total = _evaluate(net, testloader)
    st.write(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')

    # Visualize some test images and their predictions. Slice the batch so the
    # displayed grid and the printed labels refer to exactly the same images.
    images, labels = next(iter(testloader))
    images, labels = images[:num_preview], labels[:num_preview]
    with torch.no_grad():  # _evaluate already switched net to eval mode
        _, predicted = torch.max(net(images), 1)
    st.write('Predicted: ', ' '.join(f'{CIFAR10_CLASSES[p]:5s}' for p in predicted.tolist()))
    st.write('Actual: ', ' '.join(f'{CIFAR10_CLASSES[t]:5s}' for t in labels.tolist()))
    _show_figure(_image_grid_figure(images))


# Streamlit interface
st.title('CIFAR-10 Classification with PyTorch')
num_epochs = st.number_input('Enter number of epochs:', min_value=1, max_value=100, value=10)
if st.button('Run'):
    train_model(num_epochs)