# ERA-Session12 / models/resnet_lightning.py
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from torchmetrics import Accuracy

from utils.common import one_cycle_lr


class ResidualBlock(nn.Module):
    """Two 3x3 conv-BN-ReLU stages wrapped by an identity shortcut."""

    def __init__(self, channels):
        super().__init__()
        self.residual_block = nn.Sequential(
nn.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(channels),
nn.ReLU(),
nn.Conv2d(
in_channels=channels,
out_channels=channels,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(channels),
nn.ReLU(),
)

    def forward(self, x):
        # Identity shortcut: the block must preserve shape for the add to work.
        return x + self.residual_block(x)
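
# Illustrative shape check (a sketch, not part of the training code): the
# shortcut above works only because ResidualBlock preserves both the channel
# count and the spatial dimensions, e.g.:
#
#   block = ResidualBlock(channels=128)
#   out = block(torch.randn(2, 128, 16, 16))
#   assert out.shape == (2, 128, 16, 16)
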

class ResNet(L.LightningModule):
    """ResNet-9-style network for 10-class classification of 32x32 RGB
    images (e.g. CIFAR-10)."""

    def __init__(
        self,
        batch_size=512,
        shuffle=True,
        num_workers=4,
        learning_rate=0.003,
        scheduler_steps=None,
        maxlr=None,
        epochs=None,
    ):
        super().__init__()
        # Data-loading settings; stored here but not referenced in this module
        # (likely consumed by the dataloader setup elsewhere in the repo).
        self.data_dir = "./data"
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers

        # Optimizer and LR-schedule settings.
        self.learning_rate = learning_rate
        self.scheduler_steps = scheduler_steps
        self.maxlr = maxlr if maxlr is not None else learning_rate
        self.epochs = epochs

        # Prep: 3 -> 64 channels at full resolution.
        self.prep = nn.Sequential(
nn.Conv2d(
in_channels=3,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
bias=False,
),
nn.BatchNorm2d(64),
nn.ReLU(),
)
        # Layer 1: 64 -> 128 channels, halves H and W, then a residual block.
        self.layer1 = nn.Sequential(
nn.Conv2d(
in_channels=64,
out_channels=128,
kernel_size=3,
padding=1,
stride=1,
bias=False,
),
nn.MaxPool2d(kernel_size=2),
nn.BatchNorm2d(128),
nn.ReLU(),
ResidualBlock(channels=128),
)
        # Layer 2: 128 -> 256 channels, halves H and W (no residual block).
        self.layer2 = nn.Sequential(
nn.Conv2d(
in_channels=128,
out_channels=256,
kernel_size=3,
padding=1,
stride=1,
bias=False,
),
nn.MaxPool2d(kernel_size=2),
nn.BatchNorm2d(256),
nn.ReLU(),
)
        # Layer 3: 256 -> 512 channels, halves H and W, then a residual block.
        self.layer3 = nn.Sequential(
nn.Conv2d(
in_channels=256,
out_channels=512,
kernel_size=3,
padding=1,
stride=1,
bias=False,
),
nn.MaxPool2d(kernel_size=2),
nn.BatchNorm2d(512),
nn.ReLU(),
ResidualBlock(channels=512),
)
        # For 32x32 inputs the feature map is 4x4 here; this pool reduces it to 1x1.
        self.pool = nn.MaxPool2d(kernel_size=4)
        self.fc = nn.Linear(in_features=512, out_features=10, bias=False)
        # Defined but unused: the cross-entropy losses below expect raw logits,
        # so forward() never applies softmax.
        self.softmax = nn.Softmax(dim=-1)
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        x = self.prep(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # flatten (N, 512, 1, 1) -> (N, 512)
        x = self.fc(x)
        # No softmax here: the loss is computed on raw logits.
        return x
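
    # Shape trace for a CIFAR-10-sized input (N, 3, 32, 32); each MaxPool2d(2)
    # halves H and W:
    #   prep   -> (N,  64, 32, 32)
    #   layer1 -> (N, 128, 16, 16)
    #   layer2 -> (N, 256,  8,  8)
    #   layer3 -> (N, 512,  4,  4)
    #   pool   -> (N, 512,  1,  1)
    #   fc     -> (N, 10) logits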

    def configure_optimizers(self) -> Any:
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=1e-4
        )
        scheduler = one_cycle_lr(
            optimizer=optimizer,
            maxlr=self.maxlr,
            steps=self.scheduler_steps,
            epochs=self.epochs,
        )
        # Step the scheduler every optimizer step, as a one-cycle policy expects.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def training_step(self, batch, batch_idx):
        X, y = batch
        y_pred = self(X)
        loss = F.cross_entropy(y_pred, y)
        preds = torch.argmax(y_pred, dim=1)
        accuracy = self.accuracy(preds, y)
        self.log_dict({"train_loss": loss, "train_acc": accuracy}, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch
        y_pred = self(X)
        # Mean-reduced loss keeps val_loss on the same scale as train_loss
        # regardless of batch size.
        loss = F.cross_entropy(y_pred, y)
        preds = torch.argmax(y_pred, dim=1)
        accuracy = self.accuracy(preds, y)
        self.log_dict({"val_loss": loss, "val_acc": accuracy}, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        X, y = batch
        y_pred = self(X)
        loss = F.cross_entropy(y_pred, y)
        preds = torch.argmax(y_pred, dim=1)
        accuracy = self.accuracy(preds, y)
        self.log_dict({"test_loss": loss, "test_acc": accuracy}, prog_bar=True)
        return loss