
Second attempt at the NN, with convolution

Ethanell 2021-03-29 10:33:21 +02:00
parent 3933b77da1
commit b910086893

main.py (46 lines changed)

@@ -1,7 +1,6 @@
 from os.path import isfile
 import torch
-from numpy import prod
 from torch import nn
 from torch.utils.data import DataLoader
 from torchvision import datasets
@@ -36,33 +35,31 @@ def get_data(batch_size: int = 64):
     return train_dataloader, test_dataloader
 
-def generate_layers(inp: int, output: int):
-    layers = 2
-    conns = (inp+output)*2
-    stack = [nn.Linear(inp, conns), nn.ReLU()]
-    print(f"input: {inp}, output: {output}, layers: {layers}, conns: {conns}")
-    print("Generating stack...")
-    for _ in range(layers):
-        stack.append(nn.Linear(conns, conns))
-        stack.append(nn.ReLU())
-    stack += [nn.Linear(conns, output), nn.ReLU()]
-    print("Stack generated")
-    return stack
-
 # Define model
 class NeuralNetwork(nn.Module):
-    def __init__(self, stack):
+    def __init__(self):
         super(NeuralNetwork, self).__init__()
         self.flatten = nn.Flatten()
-        self.linear_relu_stack = nn.Sequential(*stack)
+        self.conv_relu_stack = nn.Sequential(
+            nn.Conv2d(3, 6, (5, 5)),
+            nn.MaxPool2d(2, 2),
+            nn.ReLU(),
+            nn.Conv2d(6, 16, (5, 5)),
+            nn.MaxPool2d(2, 2),
+            nn.ReLU(),
+        )
+        self.linear_relu_stack = nn.Sequential(
+            nn.Linear(16*(5**2), 120),
+            nn.ReLU(),
+            nn.Linear(120, 84),
+            nn.ReLU(),
+            nn.Linear(84, 10),
+            nn.ReLU(),
+        )
 
     def forward(self, x):
-        return self.linear_relu_stack(self.flatten(x))
+        x = self.conv_relu_stack(x)
+        x = x.view(-1, 16 * 5 * 5)
+        return self.linear_relu_stack(x)
 
 def train(dataloader, model, loss_fn, optimizer):
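
The hard-coded 16*(5**2) (= 400) input width of the first linear layer only works out for 32×32 RGB input (CIFAR-10-sized images; the dataset itself is not visible in this diff): each 5×5 convolution trims 4 pixels per side and each 2×2 max-pool halves the spatial size, so 32 → 28 → 14 → 10 → 5, i.e. 16 channels × 5 × 5 = 400 features. A minimal sketch verifying that arithmetic, under the 32×32 assumption:

import torch
from torch import nn

# Same conv stack as in the commit; input assumed to be 3x32x32.
conv_relu_stack = nn.Sequential(
    nn.Conv2d(3, 6, (5, 5)),   # 3x32x32 -> 6x28x28 (5x5 kernel, no padding)
    nn.MaxPool2d(2, 2),        # 6x28x28 -> 6x14x14
    nn.ReLU(),
    nn.Conv2d(6, 16, (5, 5)),  # 6x14x14 -> 16x10x10
    nn.MaxPool2d(2, 2),        # 16x10x10 -> 16x5x5
    nn.ReLU(),
)

x = torch.randn(1, 3, 32, 32)    # one dummy RGB image
print(conv_relu_stack(x).shape)  # torch.Size([1, 16, 5, 5]) -> 16*5*5 = 400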
@@ -103,8 +100,7 @@ def test(dataloader, model, loss_fn):
 def training():
     train_data, test_data = get_data()
-    stack = generate_layers(prod(test_data.dataset.data[0].shape), len(test_data.dataset.classes))
-    model = NeuralNetwork(stack).to(device)
+    model = NeuralNetwork().to(device)
     if isfile("model.pth"):
         print("Loading model from save")
         model.load_state_dict(torch.load("model.pth"))
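
One consequence of dropping the generated stack: a model.pth checkpoint saved by the previous linear-only version no longer matches the new architecture, and load_state_dict will raise a RuntimeError on the mismatched keys. A sketch of a more defensive load (the try/except and map_location are suggestions, not part of the commit; NeuralNetwork and device are the definitions from main.py above):

from os.path import isfile
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = NeuralNetwork().to(device)  # class as defined in this commit
if isfile("model.pth"):
    try:
        # map_location keeps a GPU-saved checkpoint loadable on a CPU-only machine
        model.load_state_dict(torch.load("model.pth", map_location=device))
        print("Loading model from save")
    except RuntimeError:
        # checkpoint from the old linear-only model: start fresh instead
        print("Saved model is incompatible with the new architecture, starting over")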