1
0
Fork 0

Second attempt of nn with convolution

This commit is contained in:
Ethanell 2021-03-29 10:33:21 +02:00
parent 3933b77da1
commit b910086893

46
main.py
View file

@ -1,7 +1,6 @@
from os.path import isfile from os.path import isfile
import torch import torch
from numpy import prod
from torch import nn from torch import nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import datasets from torchvision import datasets
@ -36,33 +35,31 @@ def get_data(batch_size: int = 64):
return train_dataloader, test_dataloader return train_dataloader, test_dataloader
def generate_layers(inp: int, output: int):
layers = 2
conns = (inp+output)*2
stack = [nn.Linear(inp, conns), nn.ReLU()]
print(f"input: {inp}, output: {output}, layers: {layers}, conns: {conns}")
print("Generating stack...")
for _ in range(layers):
stack.append(nn.Linear(conns, conns))
stack.append(nn.ReLU())
stack += [nn.Linear(conns, output), nn.ReLU()]
print("Stack generated")
return stack
# Define model # Define model
class NeuralNetwork(nn.Module): class NeuralNetwork(nn.Module):
def __init__(self, stack): def __init__(self):
super(NeuralNetwork, self).__init__() super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten() self.conv_relu_stack = nn.Sequential(
self.linear_relu_stack = nn.Sequential(*stack) nn.Conv2d(3, 6, (5, 5)),
nn.MaxPool2d(2, 2),
nn.ReLU(),
nn.Conv2d(6, 16, (5, 5)),
nn.MaxPool2d(2, 2),
nn.ReLU(),
)
self.linear_relu_stack = nn.Sequential(
nn.Linear(16*(5**2), 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 10),
nn.ReLU(),
)
def forward(self, x): def forward(self, x):
return self.linear_relu_stack(self.flatten(x)) x = self.conv_relu_stack(x)
x = x.view(-1, 16 * 5 * 5)
return self.linear_relu_stack(x)
def train(dataloader, model, loss_fn, optimizer): def train(dataloader, model, loss_fn, optimizer):
@ -103,8 +100,7 @@ def test(dataloader, model, loss_fn):
def training(): def training():
train_data, test_data = get_data() train_data, test_data = get_data()
stack = generate_layers(prod(test_data.dataset.data[0].shape), len(test_data.dataset.classes)) model = NeuralNetwork().to(device)
model = NeuralNetwork(stack).to(device)
if isfile("model.pth"): if isfile("model.pth"):
print("Loading model from save") print("Loading model from save")
model.load_state_dict(torch.load("model.pth")) model.load_state_dict(torch.load("model.pth"))