Skip to content
Snippets Groups Projects
Commit b9100868 authored by CHARLAIX FLORIAN p1905458's avatar CHARLAIX FLORIAN p1905458
Browse files

Second attempt of nn with convolution

parent 3933b77d
No related branches found
No related tags found
No related merge requests found
from os.path import isfile
import torch
from numpy import prod
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
......@@ -36,33 +35,31 @@ def get_data(batch_size: int = 64):
return train_dataloader, test_dataloader
def generate_layers(inp: int, output: int):
layers = 2
conns = (inp+output)*2
stack = [nn.Linear(inp, conns), nn.ReLU()]
print(f"input: {inp}, output: {output}, layers: {layers}, conns: {conns}")
print("Generating stack...")
for _ in range(layers):
stack.append(nn.Linear(conns, conns))
stack.append(nn.ReLU())
stack += [nn.Linear(conns, output), nn.ReLU()]
print("Stack generated")
return stack
# Define model
class NeuralNetwork(nn.Module):
def __init__(self, stack):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(*stack)
self.conv_relu_stack = nn.Sequential(
nn.Conv2d(3, 6, (5, 5)),
nn.MaxPool2d(2, 2),
nn.ReLU(),
nn.Conv2d(6, 16, (5, 5)),
nn.MaxPool2d(2, 2),
nn.ReLU(),
)
self.linear_relu_stack = nn.Sequential(
nn.Linear(16*(5**2), 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 10),
nn.ReLU(),
)
def forward(self, x):
return self.linear_relu_stack(self.flatten(x))
x = self.conv_relu_stack(x)
x = x.view(-1, 16 * 5 * 5)
return self.linear_relu_stack(x)
def train(dataloader, model, loss_fn, optimizer):
......@@ -103,8 +100,7 @@ def test(dataloader, model, loss_fn):
def training():
train_data, test_data = get_data()
stack = generate_layers(prod(test_data.dataset.data[0].shape), len(test_data.dataset.classes))
model = NeuralNetwork(stack).to(device)
model = NeuralNetwork().to(device)
if isfile("model.pth"):
print("Loading model from save")
model.load_state_dict(torch.load("model.pth"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment