File size: 3,665 Bytes
be853dd
 
 
 
 
 
 
669454f
be853dd
669454f
 
be853dd
669454f
 
 
 
 
 
be853dd
 
669454f
 
 
be853dd
669454f
be853dd
 
 
 
 
 
669454f
be853dd
 
669454f
be853dd
 
669454f
be853dd
 
669454f
 
 
 
be853dd
669454f
be853dd
 
 
 
669454f
be853dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
669454f
be853dd
 
 
 
 
669454f
 
 
 
 
 
 
 
a8d680c
be853dd
669454f
 
 
 
 
 
 
 
be853dd
 
669454f
be853dd
 
 
a8d680c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Define the CNN
class SimpleCNN(nn.Module):
    """Small convolutional classifier producing 10 logits per image.

    Two conv blocks (3x3 conv with padding 1 -> ReLU -> 2x2 max-pool) take a
    3-channel input to 32 feature maps while halving the spatial size twice.
    The head expects that to leave a 32x8x8 map, i.e. 32x32 input images.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 16 -> 32 channels; padding=1 keeps H and W.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        # A single pooling module, reused after both convs; halves H and W.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head: flattened 32*8*8 features -> 128 hidden -> 10 logits.
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw class logits for the batch ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(torch.relu(conv(x)))
        # fc1.in_features == 32 * 8 * 8; the -1 lets view infer the row count.
        flat = x.view(-1, self.fc1.in_features)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

# Function to train the model
# CIFAR-10 label names, indexed by the integer class ids the dataset yields.
CIFAR10_CLASSES = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]


def _make_loaders(data_root, batch_size):
    """Build the CIFAR-10 (train, test) DataLoaders, downloading data if needed.

    Images are converted to tensors and normalized per channel with mean 0.5
    and std 0.5; ``_show_predictions`` undoes this with ``img / 2 + 0.5``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
    return trainloader, testloader


def _train(net, trainloader, num_epochs, lr, momentum):
    """Train ``net`` in place with SGD + cross-entropy; return mean loss per epoch."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    net.train()

    loss_values = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / len(trainloader)
        loss_values.append(epoch_loss)
        st.write(f'Epoch {epoch + 1}, Loss: {epoch_loss:.3f}')
    return loss_values


def _plot_losses(loss_values):
    """Render the per-epoch training loss curve in the Streamlit page."""
    # Use a dedicated figure so later plots cannot draw onto these axes.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(1, len(loss_values) + 1), loss_values, marker='o')
    ax.set_title('Training Loss over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    st.pyplot(fig)
    # pyplot keeps every figure open until closed; release it once rendered.
    plt.close(fig)


def _evaluate(net, testloader):
    """Return ``(correct, total)`` top-1 prediction counts of ``net`` on ``testloader``."""
    # SimpleCNN has no dropout/batch-norm today, but eval() keeps this correct
    # if such layers are added.
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


def _show_predictions(net, testloader, num_preview):
    """Show the first ``num_preview`` test images with predicted and true labels."""
    images, labels = next(iter(testloader))
    # Preview only as many images as we print labels for.
    images, labels = images[:num_preview], labels[:num_preview]

    net.eval()
    with torch.no_grad():  # inference only; no autograd graph needed
        predicted = net(images).argmax(dim=1)

    st.write('Predicted: ', ' '.join(f'{CIFAR10_CLASSES[p]:5s}' for p in predicted.tolist()))
    st.write('Actual:    ', ' '.join(f'{CIFAR10_CLASSES[t]:5s}' for t in labels.tolist()))

    grid = torchvision.utils.make_grid(images) / 2 + 0.5  # undo Normalize(0.5, 0.5)
    fig, ax = plt.subplots()
    ax.imshow(np.transpose(grid.numpy(), (1, 2, 0)))  # CHW -> HWC for imshow
    ax.axis('off')
    st.pyplot(fig)
    plt.close(fig)


def train_model(num_epochs, *, batch_size=64, lr=0.01, momentum=0.9,
                data_root='./data', num_preview=8):
    """Train ``SimpleCNN`` on CIFAR-10 and report progress and results via Streamlit.

    Writes per-epoch losses, a loss curve, test-set accuracy and a small
    prediction preview to the page.

    Args:
        num_epochs: Number of full passes over the training set.
        batch_size: Mini-batch size for both train and test loaders.
        lr: SGD learning rate.
        momentum: SGD momentum.
        data_root: Directory the CIFAR-10 data is downloaded to / read from.
        num_preview: Number of test images shown with their labels at the end.
    """
    trainloader, testloader = _make_loaders(data_root, batch_size)
    net = SimpleCNN()

    st.write("Training the model...")
    loss_values = _train(net, trainloader, num_epochs, lr, momentum)
    st.write('Finished Training')
    _plot_losses(loss_values)

    correct, total = _evaluate(net, testloader)
    st.write(f'Accuracy of the network on the {total} test images: {100 * correct / total}%')

    _show_predictions(net, testloader, num_preview)

# Streamlit interface
# Page layout: a title, an epoch selector, and a button that starts training.
st.title('CIFAR-10 Classification with PyTorch')
num_epochs = st.number_input(
    'Enter number of epochs:',
    min_value=1,
    max_value=100,
    value=10,
)
# Training runs only when the "Run" button reports a click.
run_clicked = st.button('Run')
if run_clicked:
    train_model(num_epochs)