From b91008689319bb3cca50e41dee67fa847e2d7f53 Mon Sep 17 00:00:00 2001 From: flifloo Date: Mon, 29 Mar 2021 10:33:21 +0200 Subject: [PATCH] Second attempt of nn with convolution --- main.py | 46 +++++++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/main.py b/main.py index a52bdb1..cc9f073 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,6 @@ from os.path import isfile import torch -from numpy import prod from torch import nn from torch.utils.data import DataLoader from torchvision import datasets @@ -36,33 +35,31 @@ def get_data(batch_size: int = 64): return train_dataloader, test_dataloader -def generate_layers(inp: int, output: int): - layers = 2 - conns = (inp+output)*2 - stack = [nn.Linear(inp, conns), nn.ReLU()] - - print(f"input: {inp}, output: {output}, layers: {layers}, conns: {conns}") - - print("Generating stack...") - for _ in range(layers): - stack.append(nn.Linear(conns, conns)) - stack.append(nn.ReLU()) - - stack += [nn.Linear(conns, output), nn.ReLU()] - - print("Stack generated") - return stack - - # Define model class NeuralNetwork(nn.Module): - def __init__(self, stack): + def __init__(self): super(NeuralNetwork, self).__init__() - self.flatten = nn.Flatten() - self.linear_relu_stack = nn.Sequential(*stack) + self.conv_relu_stack = nn.Sequential( + nn.Conv2d(3, 6, (5, 5)), + nn.MaxPool2d(2, 2), + nn.ReLU(), + nn.Conv2d(6, 16, (5, 5)), + nn.MaxPool2d(2, 2), + nn.ReLU(), + ) + self.linear_relu_stack = nn.Sequential( + nn.Linear(16*(5**2), 120), + nn.ReLU(), + nn.Linear(120, 84), + nn.ReLU(), + nn.Linear(84, 10), + nn.ReLU(), + ) def forward(self, x): - return self.linear_relu_stack(self.flatten(x)) + x = self.conv_relu_stack(x) + x = x.view(-1, 16 * 5 * 5) + return self.linear_relu_stack(x) def train(dataloader, model, loss_fn, optimizer): @@ -103,8 +100,7 @@ def test(dataloader, model, loss_fn): def training(): train_data, test_data = get_data() - stack = generate_layers(prod(test_data.dataset.data[0].shape), len(test_data.dataset.classes)) - model = NeuralNetwork(stack).to(device) + model = NeuralNetwork().to(device) if isfile("model.pth"): print("Loading model from save") model.load_state_dict(torch.load("model.pth"))