Second attempt of nn with convolution
This commit is contained in:
parent
3933b77da1
commit
b910086893
1 changed files with 21 additions and 25 deletions
46
main.py
46
main.py
|
@ -1,7 +1,6 @@
|
|||
from os.path import isfile
|
||||
|
||||
import torch
|
||||
from numpy import prod
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import datasets
|
||||
|
@ -36,33 +35,31 @@ def get_data(batch_size: int = 64):
|
|||
return train_dataloader, test_dataloader
|
||||
|
||||
|
||||
def generate_layers(inp: int, output: int):
|
||||
layers = 2
|
||||
conns = (inp+output)*2
|
||||
stack = [nn.Linear(inp, conns), nn.ReLU()]
|
||||
|
||||
print(f"input: {inp}, output: {output}, layers: {layers}, conns: {conns}")
|
||||
|
||||
print("Generating stack...")
|
||||
for _ in range(layers):
|
||||
stack.append(nn.Linear(conns, conns))
|
||||
stack.append(nn.ReLU())
|
||||
|
||||
stack += [nn.Linear(conns, output), nn.ReLU()]
|
||||
|
||||
print("Stack generated")
|
||||
return stack
|
||||
|
||||
|
||||
# Define model
|
||||
class NeuralNetwork(nn.Module):
|
||||
def __init__(self, stack):
|
||||
def __init__(self):
|
||||
super(NeuralNetwork, self).__init__()
|
||||
self.flatten = nn.Flatten()
|
||||
self.linear_relu_stack = nn.Sequential(*stack)
|
||||
self.conv_relu_stack = nn.Sequential(
|
||||
nn.Conv2d(3, 6, (5, 5)),
|
||||
nn.MaxPool2d(2, 2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(6, 16, (5, 5)),
|
||||
nn.MaxPool2d(2, 2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.linear_relu_stack = nn.Sequential(
|
||||
nn.Linear(16*(5**2), 120),
|
||||
nn.ReLU(),
|
||||
nn.Linear(120, 84),
|
||||
nn.ReLU(),
|
||||
nn.Linear(84, 10),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_relu_stack(self.flatten(x))
|
||||
x = self.conv_relu_stack(x)
|
||||
x = x.view(-1, 16 * 5 * 5)
|
||||
return self.linear_relu_stack(x)
|
||||
|
||||
|
||||
def train(dataloader, model, loss_fn, optimizer):
|
||||
|
@ -103,8 +100,7 @@ def test(dataloader, model, loss_fn):
|
|||
def training():
|
||||
train_data, test_data = get_data()
|
||||
|
||||
stack = generate_layers(prod(test_data.dataset.data[0].shape), len(test_data.dataset.classes))
|
||||
model = NeuralNetwork(stack).to(device)
|
||||
model = NeuralNetwork().to(device)
|
||||
if isfile("model.pth"):
|
||||
print("Loading model from save")
|
||||
model.load_state_dict(torch.load("model.pth"))
|
||||
|
|
Reference in a new issue