# Spaces: Runtime error
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision | |
import torchvision.transforms as transforms | |
import matplotlib.pyplot as plt | |
import numpy as np | |
# Define the CNN
class SimpleCNN(nn.Module):
    """Two-convolution CNN that maps 3-channel images to 10 class logits.

    Each convolution keeps the spatial size (3x3 kernel, padding 1) and is
    followed by ReLU and a 2x2 max-pool, so height and width are halved
    twice. The first linear layer expects ``32 * 8 * 8`` features per sample,
    which corresponds to 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the raw (un-softmaxed) class scores for ``x``.

        Args:
            x: Image tensor with 3 channels whose pooled feature map holds
                exactly ``fc1.in_features`` values per sample.

        Returns:
            Tensor of shape ``(num_samples, 10)``.

        Raises:
            RuntimeError: If the per-sample feature count after the
                convolution stack does not match ``fc1.in_features``.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))

        # A plain view(-1, N) would silently fold surplus features into the
        # batch dimension for wrongly sized inputs, so check explicitly.
        expected = self.fc1.in_features
        per_sample = x.shape[-3] * x.shape[-2] * x.shape[-1]
        if per_sample != expected:
            raise RuntimeError(
                f'SimpleCNN expected {expected} features per sample after '
                f'pooling but got {per_sample}; inputs must be 3x32x32 images'
            )

        x = x.view(-1, expected)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Function to train the model
def train_model(num_epochs):
    """Train SimpleCNN on CIFAR-10 and report progress in the Streamlit page.

    The function downloads CIFAR-10 into ``./data`` (if not already present),
    trains a fresh ``SimpleCNN`` with SGD, writes the mean loss of every
    epoch, plots the loss curve, reports the accuracy on the test split and
    finally shows a few test images with their predicted and actual classes.

    Args:
        num_epochs: Number of full passes over the training set.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
    CIFAR10_CLASSES = [
        'plane', 'car', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ]

    net = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    loss_values = []

    st.write("Training the model...")
    net.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Mean batch loss for this epoch.
        epoch_loss = running_loss / len(trainloader)
        loss_values.append(epoch_loss)
        st.write(f'Epoch {epoch + 1}, Loss: {epoch_loss:.3f}')
    st.write('Finished Training')

    # Plot the loss values on a dedicated figure. Passing the figure object
    # explicitly avoids Matplotlib's global state, which would otherwise be
    # shared with the image preview drawn further down.
    loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
    loss_ax.plot(range(1, num_epochs + 1), loss_values, marker='o')
    loss_ax.set_title('Training Loss over Epochs')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    st.pyplot(loss_fig)
    plt.close(loss_fig)  # release the figure so reruns do not accumulate them

    # Measure accuracy on the test split without tracking gradients.
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    st.write(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')

    # Visualize some test images and their predictions. Only the images that
    # get a label below are drawn, so the picture and the text line up.
    images, labels = next(iter(testloader))
    preview_count = min(8, images.size(0))
    images = images[:preview_count]
    labels = labels[:preview_count]
    with torch.no_grad():
        predicted = net(images).argmax(dim=1)

    st.write('Predicted: ', ' '.join(f'{CIFAR10_CLASSES[idx]:5s}' for idx in predicted.tolist()))
    st.write('Actual: ', ' '.join(f'{CIFAR10_CLASSES[idx]:5s}' for idx in labels.tolist()))

    grid = torchvision.utils.make_grid(images) / 2 + 0.5  # Unnormalize
    preview_fig, preview_ax = plt.subplots()
    # make_grid returns channels-first data; imshow wants channels last.
    preview_ax.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    st.pyplot(preview_fig)
    plt.close(preview_fig)
# Streamlit interface
st.title('CIFAR-10 Classification with PyTorch')

# Integer bounds and default, limiting the epoch count to 1..100.
num_epochs = st.number_input('Enter number of epochs:', min_value=1, max_value=100, value=10)

# Training is started only by an explicit click on the button.
if st.button('Run'):
    train_model(num_epochs)