Second attempt of nn with convolution
This commit is contained in:
parent
3933b77da1
commit
b910086893
1 changed files with 21 additions and 25 deletions
46
main.py
46
main.py
|
@ -1,7 +1,6 @@
|
||||||
from os.path import isfile
|
from os.path import isfile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from numpy import prod
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import datasets
|
from torchvision import datasets
|
||||||
|
@ -36,33 +35,31 @@ def get_data(batch_size: int = 64):
|
||||||
return train_dataloader, test_dataloader
|
return train_dataloader, test_dataloader
|
||||||
|
|
||||||
|
|
||||||
def generate_layers(inp: int, output: int):
|
|
||||||
layers = 2
|
|
||||||
conns = (inp+output)*2
|
|
||||||
stack = [nn.Linear(inp, conns), nn.ReLU()]
|
|
||||||
|
|
||||||
print(f"input: {inp}, output: {output}, layers: {layers}, conns: {conns}")
|
|
||||||
|
|
||||||
print("Generating stack...")
|
|
||||||
for _ in range(layers):
|
|
||||||
stack.append(nn.Linear(conns, conns))
|
|
||||||
stack.append(nn.ReLU())
|
|
||||||
|
|
||||||
stack += [nn.Linear(conns, output), nn.ReLU()]
|
|
||||||
|
|
||||||
print("Stack generated")
|
|
||||||
return stack
|
|
||||||
|
|
||||||
|
|
||||||
# Define model
|
# Define model
|
||||||
class NeuralNetwork(nn.Module):
|
class NeuralNetwork(nn.Module):
|
||||||
def __init__(self, stack):
|
def __init__(self):
|
||||||
super(NeuralNetwork, self).__init__()
|
super(NeuralNetwork, self).__init__()
|
||||||
self.flatten = nn.Flatten()
|
self.conv_relu_stack = nn.Sequential(
|
||||||
self.linear_relu_stack = nn.Sequential(*stack)
|
nn.Conv2d(3, 6, (5, 5)),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(6, 16, (5, 5)),
|
||||||
|
nn.MaxPool2d(2, 2),
|
||||||
|
nn.ReLU(),
|
||||||
|
)
|
||||||
|
self.linear_relu_stack = nn.Sequential(
|
||||||
|
nn.Linear(16*(5**2), 120),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(120, 84),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(84, 10),
|
||||||
|
nn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.linear_relu_stack(self.flatten(x))
|
x = self.conv_relu_stack(x)
|
||||||
|
x = x.view(-1, 16 * 5 * 5)
|
||||||
|
return self.linear_relu_stack(x)
|
||||||
|
|
||||||
|
|
||||||
def train(dataloader, model, loss_fn, optimizer):
|
def train(dataloader, model, loss_fn, optimizer):
|
||||||
|
@ -103,8 +100,7 @@ def test(dataloader, model, loss_fn):
|
||||||
def training():
|
def training():
|
||||||
train_data, test_data = get_data()
|
train_data, test_data = get_data()
|
||||||
|
|
||||||
stack = generate_layers(prod(test_data.dataset.data[0].shape), len(test_data.dataset.classes))
|
model = NeuralNetwork().to(device)
|
||||||
model = NeuralNetwork(stack).to(device)
|
|
||||||
if isfile("model.pth"):
|
if isfile("model.pth"):
|
||||||
print("Loading model from save")
|
print("Loading model from save")
|
||||||
model.load_state_dict(torch.load("model.pth"))
|
model.load_state_dict(torch.load("model.pth"))
|
||||||
|
|
Reference in a new issue