26 lines
726 B
Python
26 lines
726 B
Python
from torch import nn
|
|
|
|
|
|
class NeuralNetwork(nn.Module):
    """LeNet-style classifier: two conv/ReLU/max-pool stages followed by an MLP head.

    The linear head's input width (16 * 5 * 5) matches an input of shape
    ``(batch, 3, 32, 32)``. Each unpadded 5x5 convolution shrinks the spatial
    size by 4, and each 2x2 max-pool halves it: 32 -> 28 -> 14 -> 10 -> 5.

    Args:
        l1: Width of the first hidden linear layer.
        l2: Width of the second hidden linear layer.
    """

    # Per-sample feature count after the conv stack for a 32x32 input:
    # 16 output channels times a 5x5 spatial map.
    _FLAT_FEATURES = 16 * 5 * 5

    def __init__(self, l1=120, l2=84):
        super().__init__()
        self.conv_relu_stack = nn.Sequential(
            nn.Conv2d(3, 6, (5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, (5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(self._FLAT_FEATURES, l1),
            nn.ReLU(),
            nn.Linear(l1, l2),
            nn.ReLU(),
            nn.Linear(l2, 10),
        )

    def forward(self, x):
        """Return 10 unnormalised class scores per sample (no softmax is applied).

        Args:
            x: Image batch; the layer sizes imply shape ``(batch, 3, 32, 32)``.

        Returns:
            Tensor of shape ``(batch, 10)``.
        """
        x = self.conv_relu_stack(x)
        # Flatten per sample, keeping the batch axis. A ``view(-1, 400)`` here
        # would silently fold extra spatial positions into the batch for
        # wrongly-sized inputs, and fails on non-contiguous (channels_last)
        # activations. ``flatten`` copies only when it has to.
        x = x.flatten(1)
        return self.linear_relu_stack(x)