1
0
Fork 0
This repository has been archived on 2024-02-17. You can view files and clone it, but cannot push or open issues or pull requests.
TP_IA/NeuralNetwork.py

27 lines
726 B
Python

from torch import nn
class NeuralNetwork(nn.Module):
    """Small convolutional classifier: two conv/pool stages and three linear layers.

    The convolutional stage takes 3-channel images and produces 16 feature
    maps. The linear stage expects those maps to be 5x5 (400 features per
    sample), which is what the two unpadded 5x5 convolutions and two 2x2
    max-pools yield for inputs such as 32x32. The final layer emits 10 raw
    scores per sample; no softmax is applied here.

    Args:
        l1: Width of the first hidden linear layer.
        l2: Width of the second hidden linear layer.
    """

    # Features per sample entering the linear stack: 16 channels of 5x5 maps.
    # Kept at class level so __init__ and forward cannot drift apart.
    _FLAT_FEATURES = 16 * 5 * 5

    def __init__(self, l1: int = 120, l2: int = 84) -> None:
        super().__init__()
        self.conv_relu_stack = nn.Sequential(
            nn.Conv2d(3, 6, (5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, (5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(self._FLAT_FEATURES, l1),
            nn.ReLU(),
            nn.Linear(l1, l2),
            nn.ReLU(),
            nn.Linear(l2, 10),
        )

    def forward(self, x):
        """Run the conv stack, flatten to 400 features per row, and classify.

        Args:
            x: Image tensor with 3 channels whose spatial size reduces to
                5x5 after the convolutional stack.

        Returns:
            Tensor with 10 scores per flattened row.
        """
        features = self.conv_relu_stack(x)
        # reshape (not view): the pooled activations may be non-contiguous in
        # NCHW order (e.g. channels_last memory format), where view() raises.
        # For contiguous tensors reshape returns the same view as before.
        flat = features.reshape(-1, self._FLAT_FEATURES)
        return self.linear_relu_stack(flat)