import torch
import torch.nn as nn


def _conv3x3(in_channels: int, out_channels: int) -> nn.Conv2d:
    """Return a bias-free 3x3 convolution (stride 1, padding 1)."""
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )


class ResBlock(nn.Module):
    """Residual block: two (3x3 conv -> BatchNorm -> ReLU) stages plus a skip.

    Both convolutions keep the channel count and use padding 1 with stride 1,
    so the branch output has the same shape as the input and can be added
    to it directly.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        # Module order inside the Sequential is part of the state-dict layout
        # (resblock.0 ... resblock.5); keep it stable.
        self.resblock = nn.Sequential(
            _conv3x3(channels, channels),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            _conv3x3(channels, channels),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the convolutional branch."""
        # The second ReLU is applied inside the branch, i.e. before the sum.
        return x + self.resblock(x)


class CustomResnet(nn.Module):
    """Small ResNet-style classifier for 3-channel images.

    Structure:
        prep   : conv(3 -> 64) -> BN -> ReLU
        layer1 : conv(64 -> 128) -> MaxPool(2) -> BN -> ReLU -> ResBlock(128)
        layer2 : conv(128 -> 256) -> MaxPool(2) -> BN -> ReLU
        layer3 : conv(256 -> 512) -> MaxPool(2) -> BN -> ReLU -> ResBlock(512)
        pool   : MaxPool(4)
        fc     : Linear(512 -> num_classes), no bias

    The three 2x2 poolings followed by the 4x4 pooling shrink each spatial
    dimension by a factor of 32, so the input must be sized such that the
    pooled feature map is 1x1 (for example 32x32); only then does each sample
    flatten to the 512 features ``fc`` expects.

    Args:
        num_classes: Number of output scores per sample. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.prep = nn.Sequential(
            _conv3x3(3, 64),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = nn.Sequential(
            _conv3x3(64, 128),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            ResBlock(channels=128),
        )
        self.layer2 = nn.Sequential(
            _conv3x3(128, 256),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        self.layer3 = nn.Sequential(
            _conv3x3(256, 512),
            nn.MaxPool2d(kernel_size=2),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            ResBlock(channels=512),
        )
        self.pool = nn.MaxPool2d(kernel_size=4)
        self.fc = nn.Linear(in_features=512, out_features=num_classes, bias=False)
        # Not applied in forward(); kept as an attribute so existing code that
        # references ``model.softmax`` continues to work.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute unnormalized class scores.

        Args:
            x: Batch of images with 3 channels, shaped (N, 3, H, W).

        Returns:
            Tensor of shape (N, num_classes) holding the raw ``fc`` outputs;
            softmax is intentionally not applied here.

        Raises:
            RuntimeError: If the pooled feature map is not 1x1, so that the
                flattened features do not match ``fc``'s 512 inputs.
        """
        x = self.prep(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.pool(x)
        # Flatten per sample instead of ``view(-1, 512)``: the latter silently
        # folds leftover spatial positions into the batch dimension when the
        # pooled map is larger than 1x1, yielding more rows than samples.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)