diff --git a/main.py b/main.py
index a52bdb177be108922f4e25ac0137c13bd81922bc..cc9f07327c1df504746f7293480fe22181db9343 100644
--- a/main.py
+++ b/main.py
@@ -1,7 +1,6 @@
 from os.path import isfile
 
 import torch
-from numpy import prod
 from torch import nn
 from torch.utils.data import DataLoader
 from torchvision import datasets
@@ -36,33 +35,31 @@ def get_data(batch_size: int = 64):
     return train_dataloader, test_dataloader
 
 
-def generate_layers(inp: int, output: int):
-    layers = 2
-    conns = (inp+output)*2
-    stack = [nn.Linear(inp, conns), nn.ReLU()]
-
-    print(f"input: {inp}, output: {output}, layers: {layers}, conns: {conns}")
-
-    print("Generating stack...")
-    for _ in range(layers):
-        stack.append(nn.Linear(conns, conns))
-        stack.append(nn.ReLU())
-
-    stack += [nn.Linear(conns, output), nn.ReLU()]
-
-    print("Stack generated")
-    return stack
-
-
 # Define model
 class NeuralNetwork(nn.Module):
-    def __init__(self, stack):
+    def __init__(self):
         super(NeuralNetwork, self).__init__()
-        self.flatten = nn.Flatten()
-        self.linear_relu_stack = nn.Sequential(*stack)
+        self.conv_relu_stack = nn.Sequential(
+            nn.Conv2d(3, 6, (5, 5)),
+            nn.MaxPool2d(2, 2),
+            nn.ReLU(),
+            nn.Conv2d(6, 16, (5, 5)),
+            nn.MaxPool2d(2, 2),
+            nn.ReLU(),
+        )
+        self.linear_relu_stack = nn.Sequential(
+            nn.Linear(16*(5**2), 120),
+            nn.ReLU(),
+            nn.Linear(120, 84),
+            nn.ReLU(),
+            nn.Linear(84, 10),
+            nn.ReLU(),
+        )
 
     def forward(self, x):
-        return self.linear_relu_stack(self.flatten(x))
+        x = self.conv_relu_stack(x)
+        x = x.view(-1, 16 * 5 * 5)
+        return self.linear_relu_stack(x)
 
 
 def train(dataloader, model, loss_fn, optimizer):
@@ -103,8 +100,7 @@ def test(dataloader, model, loss_fn):
 def training():
     train_data, test_data = get_data()
 
-    stack = generate_layers(prod(test_data.dataset.data[0].shape), len(test_data.dataset.classes))
-    model = NeuralNetwork(stack).to(device)
+    model = NeuralNetwork().to(device)
     if isfile("model.pth"):
         print("Loading model from save")
         model.load_state_dict(torch.load("model.pth"))