|
class Cnn(nn.Module):
    """Two-stage convolutional classifier that outputs a softmax over 10 classes.

    Layer sequence (as wired in ``forward``)::

        conv1 (1->32, 3x3) -> max_pool(2) -> ReLU
        conv2 (32->64, 3x3) -> Dropout2d -> max_pool(2) -> ReLU
        flatten -> fc1 (1600->100) -> Dropout -> ReLU
        fc2 (100->10) -> softmax

    Args:
        dropout: Drop probability shared by the channel-wise ``Dropout2d``
            after ``conv2`` and the element-wise ``Dropout`` after ``fc1``.
    """

    def __init__(self, dropout=0.5):
        super().__init__()
        # Submodules are registered in the same order as before so parameter
        # initialisation (RNG consumption) and state_dict key order are unchanged.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d(p=dropout)
        # 1600 = 64 channels * 5 * 5. That is the feature-map size left after two
        # (3x3 conv, 2x2 pool) stages on a 28x28 input: 28->26->13->11->5.
        # Inputs of other spatial sizes will not match this layer.
        self.fc1 = nn.Linear(1600, 100)
        self.fc2 = nn.Linear(100, 10)
        self.fc1_drop = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return class probabilities for a batch of single-channel images.

        Args:
            x: Batched input of shape ``(N, 1, H, W)``. The hard-coded ``fc1``
                width requires ``H == W == 28``. ``N`` may be 0.

        Returns:
            Tensor of shape ``(N, 10)`` whose rows are softmax probabilities.
        """
        x = torch.relu(F.max_pool2d(self.conv1(x), 2))
        x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Collapse every dimension except the batch dimension.
        # view(-1, C*H*W) fails on an empty batch, where -1 is ambiguous.
        # It also fails on layouts a view cannot express, e.g. channels_last.
        # flatten copies only when a view is impossible.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1_drop(self.fc1(x)))
        # NOTE(review): this returns probabilities, not logits. If training uses
        # nn.CrossEntropyLoss (which applies log-softmax itself), the softmax
        # gets applied twice. Pair this output with NLLLoss on its log instead,
        # or feed CrossEntropyLoss the raw fc2 output. TODO confirm the training loss.
        return torch.softmax(self.fc2(x), dim=-1)