import torch
import torch.nn as nn


class NeuralNet(nn.Module):
    """Fully connected classifier: flatten -> Linear/ReLU -> Linear/ReLU -> Linear.

    The defaults reproduce the original fixed architecture:
    784 -> 512 -> 512 -> 10. The 784 inputs match a flattened 28x28 image,
    presumably an MNIST-style dataset; confirm against the data pipeline.

    Args:
        in_features: Number of features after flattening every non-batch
            dimension of the input. Defaults to 28 * 28.
        hidden_features: Width of the two hidden layers. Defaults to 512.
        num_classes: Number of output units. Defaults to 10.
    """

    def __init__(
        self,
        in_features: int = 28 * 28,
        hidden_features: int = 512,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        # nn.Flatten() keeps dim 0 (batch) and flattens everything after it.
        self.flatten = nn.Flatten()
        # Attribute names and layer order are kept as before, so state_dict
        # keys (layers.0/2/4.weight|bias) still match existing checkpoints.
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of inputs to unnormalized class scores.

        Args:
            x: Batch whose non-batch dimensions multiply to ``in_features``.

        Returns:
            A ``(batch, num_classes)`` tensor of raw logits. No softmax is
            applied.
        """
        x = self.flatten(x)
        return self.layers(x)