|
from torchvision.transforms import ToTensor |
|
from torchvision.transforms import v2 |
|
from torchvision import transforms |
|
|
|
import matplotlib.pyplot as plt |
|
from time import time |
|
from torch import nn |
|
import pandas as pd |
|
import numpy as np |
|
import torch, os |
|
from torch.optim.lr_scheduler import ReduceLROnPlateau |
|
from tqdm import tqdm |
|
|
|
|
|
# Nominal input image shape. Not referenced by the code visible in this file;
# presumably (height, width, channels) — TODO confirm against the data loader.
input_shape = (224, 224, 3)

# Pick the compute backend once at import time: CUDA first, then Apple's MPS
# backend, and finally the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
|
|
|
|
|
class MakiAlexNet(nn.Module):
    """AlexNet-style CNN with an average-pooling + single linear layer head.

    Five convolutional layers (ReLU after each, max-pooling after the 1st,
    2nd and 5th) are followed by average pooling, dropout and one linear
    layer that produces ``num_classes`` logits.

    Forward hooks on ``conv5`` and ``f_linear`` store their most recent
    outputs in ``self.layer_outputs`` under the keys ``"Conv2d"`` and
    ``"Linear"`` (the hooked modules' class names).
    """

    def __init__(self, num_classes: int = 2) -> None:
        """Build the layers and register the output-capturing hooks.

        Args:
            num_classes: Number of logits produced by the final linear layer.
        """
        super().__init__()
        self.num_classes = num_classes

        # Feature extractor.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=1)
        self.conv2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1)

        # Stateless layers shared by several stages of forward().
        self.activation = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.dropout = nn.Dropout(p=0.5)

        # Classifier head: one linear layer on the 256 pooled channels.
        self.f_linear = nn.Linear(256, self.num_classes)

        # For a 224x224 input the last max-pool yields a 5x5 map, so a 5x5
        # average pool collapses it to 1x1 (one value per channel).
        self.gap = nn.AvgPool2d(5)

        # Most recent hooked outputs, keyed by the module's class name.
        self.layer_outputs = {}

        # conv5's hook fires before the ReLU that follows it in forward(),
        # so the stored "Conv2d" entry is the pre-activation feature map.
        self.conv5.register_forward_hook(self._save_layer_output)
        self.f_linear.register_forward_hook(self._save_layer_output)

    def _save_to_output_weights(self, module, input, output):
        """Forward hook that stores input, output and weights of ``module``.

        Not registered on any layer in this class; kept as an alternative to
        ``_save_layer_output``.
        """
        self.layer_outputs[module.__class__.__name__] = {"input": input, "output": output, "weights": module.weight.data}

    def _save_layer_output(self, module, input, output):
        """Forward hook that stores ``output`` under the module's class name.

        A later call from a module of the same class overwrites the entry.
        """
        self.layer_outputs[module.__class__.__name__] = output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and return the ``num_classes`` logits.

        The original docstring describes the task as learning a left or
        right prediction.
        """
        x = self.conv1(x)
        x = self.activation(x)
        x = self.maxpool(x)

        x = self.conv2(x)
        x = self.activation(x)
        x = self.maxpool(x)

        x = self.conv3(x)
        x = self.activation(x)

        x = self.conv4(x)
        x = self.activation(x)

        x = self.conv5(x)
        x = self.activation(x)
        x = self.maxpool(x)

        # Pool to 1x1 and drop the two trailing singleton spatial dims so the
        # linear layer sees one 256-vector per sample.
        x = self.gap(x).squeeze(-1).squeeze(-1)

        x = self.dropout(x)
        x = self.f_linear(x)
        return x
|
|
|
|
|
|
|
def load_best_accuracy(path):
    """Return the best validation accuracy recorded by an earlier run.

    Args:
        path: Text file expected to hold a single floating-point number.

    Returns:
        The stored value, or ``0.0`` when the file does not exist, is empty,
        or does not parse as a float (a message is printed in the last two
        cases instead of aborting the run).
    """
    try:
        with open(path, "r", encoding="utf-8") as file:
            text = file.read().strip()
    except FileNotFoundError:
        return 0.0
    try:
        return float(text)
    except ValueError:
        print(f"Ignoring unreadable best-accuracy record in {path!r}: {text!r}")
        return 0.0


def _write_best_accuracy(path, value):
    """Overwrite ``path`` with the textual form of ``value``."""
    with open(path, "w", encoding="utf-8") as file:
        file.write(f"{value}")


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module to train; switched to training mode here.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Mean per-batch loss, or ``0.0`` when ``loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for i, (inputs, labels) in enumerate(loader):
        if i % 10 == 0:
            print(f"Internal Loop of batches: {i}")
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Count batches while iterating so loaders without __len__ work and an
    # empty loader does not divide by zero.
    return running_loss / num_batches if num_batches else 0.0


def evaluate(model, loader, criterion, device):
    """Compute mean loss and accuracy of ``model`` over ``loader``.

    Args:
        model: Module to evaluate; switched to evaluation mode here.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean per-batch loss, accuracy in percent)``; both are ``0.0`` when
        ``loader`` yields nothing.
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            running_loss += criterion(outputs, labels).item()
            num_batches += 1

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    mean_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = 100 * correct / total if total else 0.0
    return mean_loss, accuracy


def main():
    """Train ``MakiAlexNet`` and checkpoint the most accurate weights.

    The best accuracy is persisted in ``best_model_2.0.txt`` and carried over
    between runs, while the model itself always starts from fresh weights;
    a new run therefore only overwrites ``alexnet_2.0.pth`` once it beats the
    recorded value.
    """
    # Imported here so that importing this module does not load the dataset.
    from dataset_creation import test_loader, train_loader

    epochs = 35
    best_accuracy_path = "best_model_2.0.txt"
    checkpoint_path = "alexnet_2.0.pth"

    model = MakiAlexNet()
    model.to(device)
    print(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.00001 * 5, weight_decay=0.0001, momentum=0.9)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

    best_accuracy = load_best_accuracy(best_accuracy_path)
    # Write the starting value back so the record file always exists.
    _write_best_accuracy(best_accuracy_path, best_accuracy)

    for epoch in tqdm(range(epochs), desc="Training Epoch Cycle"):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f'Epoch [{epoch + 1}] training loss: {train_loss:.3f}')

        val_loss, val_accuracy = evaluate(model, test_loader, criterion, device)
        print(f'Epoch [{epoch + 1}] validation loss: {val_loss:.3f}, accuracy: {val_accuracy:.2f}%')

        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), checkpoint_path)
            _write_best_accuracy(best_accuracy_path, best_accuracy)

        # The scheduler lowers the learning rate when validation loss stalls.
        scheduler.step(val_loss)


if __name__ == "__main__":
    main()
|
|
|
|