from torch import nn


class NeuralNetwork(nn.Module):
    """Small convolutional classifier: two conv/pool stages and an MLP head.

    The convolutional stack applies two unpadded 5x5 convolutions, each
    followed by ReLU and 2x2 max pooling, mapping 3 input channels to 16
    feature maps. The head is three fully connected layers; the last one has
    no activation, so ``forward`` returns raw scores (logits).

    The first fully connected layer is sized for 16 feature maps of 5x5,
    which is what the convolutional stack produces for e.g. 32x32 inputs
    (32 -> 28 -> 14 -> 10 -> 5).

    Args:
        l1: Width of the first hidden fully connected layer.
        l2: Width of the second hidden fully connected layer.
        num_classes: Number of output scores per sample.
    """

    # Channels * height * width leaving the conv stack; defined once so the
    # first linear layer and the documentation cannot drift apart.
    _FLAT_FEATURES = 16 * 5 * 5

    def __init__(self, l1=120, l2=84, num_classes=10):
        super().__init__()
        # Attribute names and Sequential indices are kept stable so that
        # previously saved state dicts keep loading.
        self.conv_relu_stack = nn.Sequential(
            nn.Conv2d(3, 6, (5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, (5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(self._FLAT_FEATURES, l1),
            nn.ReLU(),
            nn.Linear(l1, l2),
            nn.ReLU(),
            nn.Linear(l2, num_classes),
        )

    def forward(self, x):
        """Return class scores for a batch of images.

        Args:
            x: Tensor of shape ``(N, 3, H, W)``, or a single unbatched image
                of shape ``(3, H, W)`` which is treated as a batch of one.

        Returns:
            Tensor of shape ``(N, num_classes)`` holding unnormalized scores.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not reduce to the
                5x5 feature maps the first linear layer was built for.
        """
        if x.dim() == 3:
            # Keep accepting a lone (C, H, W) image as a batch of one.
            x = x.unsqueeze(0)
        x = self.conv_relu_stack(x)
        # Flatten per sample instead of ``view(-1, features)``: with ``-1`` a
        # wrong spatial size could be silently folded into the batch
        # dimension, returning scores for a different number of samples.
        # Keeping dim 0 intact makes the linear layer reject such input.
        x = x.flatten(start_dim=1)
        return self.linear_relu_stack(x)