import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    """Three-stage convolutional classifier with a fully connected head.

    Each stage applies a 3x3 convolution, ReLU, batch normalisation and max
    pooling (in that order). The resulting feature map is reduced to one
    value per channel by adaptive average pooling and passed through three
    linear layers that produce ``num_classes`` raw scores (no softmax).

    Args:
        input_channels: Number of channels in the input images.
        num_classes: Size of the output layer. Defaults to 10, which was the
            previously hard-coded value.
    """

    def __init__(self, input_channels, num_classes=10):
        super().__init__()
        self.input_channels = input_channels
        self.num_classes = num_classes

        # NOTE: layers are registered in the same order as before so that
        # state_dict keys and seeded initialisation stay unchanged.

        # Stage 1: input_channels -> 32 feature maps.
        self.conv1 = nn.Conv2d(self.input_channels, 32, kernel_size=(3, 3))
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 3))
        self.dropout1 = nn.Dropout(0.3)

        # Stage 2: 32 -> 64 feature maps; pooling acts on the width axis only.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=(1, 3))

        # Stage 3: 64 -> 128 feature maps.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3))
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.dropout2 = nn.Dropout(0.3)

        # Collapse the remaining spatial extent to 1x1 so the head does not
        # depend on the input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Classification head.
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 512)
        self.dropout3 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(512, num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise the parameters of a single sub-module.

        Intended to be passed to ``nn.Module.apply``. Modules of any other
        type are left untouched.

        * ``Conv2d``: Xavier-normal weights, zero bias.
        * ``BatchNorm2d``: weight 1, bias 0.
        * ``Linear``: weights uniform in ``[-1/sqrt(in_features),
          1/sqrt(in_features)]``, zero bias.
        """
        if isinstance(module, nn.Conv2d):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Linear):
            bound = 1.0 / module.in_features ** 0.5
            nn.init.uniform_(module.weight, -bound, bound)
            # Linear layers may be built with bias=False; mirror the Conv2d
            # branch instead of failing on a missing bias.
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @staticmethod
    def _conv_stage(x, conv, norm, pool):
        """Run one conv -> ReLU -> batch-norm -> max-pool stage."""
        return pool(norm(F.relu(conv(x))))

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Tensor of shape ``(batch, input_channels, height, width)``.
                With the kernel sizes used here (and the default pooling
                strides) the spatial size must be at least 20 x 44
                (height x width) for every stage to produce a non-empty map.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding unnormalised
            scores (logits).
        """
        x = self._conv_stage(x, self.conv1, self.batchnorm1, self.pool1)
        x = self.dropout1(x)

        x = self._conv_stage(x, self.conv2, self.batchnorm2, self.pool2)

        x = self._conv_stage(x, self.conv3, self.batchnorm3, self.pool3)
        x = self.dropout2(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # (batch, 128, 1, 1) -> (batch, 128)

        # The same Dropout module is applied after both hidden layers.
        x = self.dropout3(F.relu(self.fc1(x)))
        x = self.dropout3(F.relu(self.fc2(x)))
        return self.fc3(x)