import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    def __init__(self, input_channels):
        super(CNN, self).__init__()
        self.input_channels = input_channels

        # Feature extractor: three conv blocks with batch norm, max pooling,
        # and dropout for regularization.
        self.conv1 = nn.Conv2d(self.input_channels, 32, kernel_size=(3, 3))
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 3))
        self.dropout1 = nn.Dropout(0.3)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 3))

        self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3))
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.dropout2 = nn.Dropout(0.3)

        # Global average pooling collapses the spatial dimensions to 1x1,
        # so the classifier head does not depend on the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Fully connected classifier head with 10 output classes.
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 512)
        self.dropout3 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(512, 10)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Xavier-normal weights for conv layers, unit scale / zero shift for
        # batch norm, and uniform(-1/sqrt(n), 1/sqrt(n)) for linear layers.
        if isinstance(module, nn.Conv2d):
            nn.init.xavier_normal_(module.weight.data)
            if module.bias is not None:
                nn.init.constant_(module.bias.data, 0)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.constant_(module.weight.data, 1)
            nn.init.constant_(module.bias.data, 0)
        elif isinstance(module, nn.Linear):
            n = module.in_features
            y = 1.0 / n ** (1 / 2)
            module.weight.data.uniform_(-y, y)
            module.bias.data.fill_(0)

    def forward(self, x):
        # Block 1: conv -> ReLU -> batch norm -> pool -> dropout
        x = self.conv1(x)
        x = F.relu(x)
        x = self.batchnorm1(x)
        x = self.pool1(x)
        x = self.dropout1(x)

        # Block 2
        x = self.conv2(x)
        x = F.relu(x)
        x = self.batchnorm2(x)
        x = self.pool2(x)

        # Block 3
        x = self.conv3(x)
        x = F.relu(x)
        x = self.batchnorm3(x)
        x = self.pool3(x)
        x = self.dropout2(x)

        # Global average pool, then flatten to (batch, 128)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        # Classifier head; raw logits are returned (no softmax)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout3(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout3(x)
        x = self.fc3(x)
        return x
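

# Minimal usage sketch (not part of the original file): the 3-channel input,
# batch size, and 64x64 resolution below are illustrative assumptions. The
# input must be large enough to survive the conv/pool stack; 64x64 works,
# whereas 32x32 would collapse to zero width before conv3.
if __name__ == "__main__":
    import torch

    model = CNN(input_channels=3)
    model.eval()  # disable dropout, use running batch-norm statistics

    dummy = torch.randn(4, 3, 64, 64)  # (batch, channels, height, width)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([4, 10])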