diff --git a/NeuralNetwork.py b/NeuralNetwork.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0ab62062cf713a15124ca3ff19b64ee79fb9551
--- /dev/null
+++ b/NeuralNetwork.py
@@ -0,0 +1,26 @@
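+"""LeNet-style CNN for CIFAR-10 with tunable hidden-layer widths l1 and l2."""
+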
+from torch import nn
+
+
+class NeuralNetwork(nn.Module):
+    def __init__(self, l1=120, l2=84):
+        super().__init__()
+        # Feature extractor: for a 3x32x32 CIFAR-10 image, each conv(5x5)
+        # shrinks the spatial size by 4 and each pool halves it.
+        self.conv_relu_stack = nn.Sequential(
+            nn.Conv2d(3, 6, (5, 5)),   # 3x32x32 -> 6x28x28
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),        # -> 6x14x14
+            nn.Conv2d(6, 16, (5, 5)),  # -> 16x10x10
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),        # -> 16x5x5
+        )
+        # Classifier head: l1 and l2 are the widths tuned by Ray Tune; the
+        # last layer emits raw logits for the 10 CIFAR-10 classes.
+        self.linear_relu_stack = nn.Sequential(
+            nn.Linear(16 * 5 * 5, l1),
+            nn.ReLU(),
+            nn.Linear(l1, l2),
+            nn.ReLU(),
+            nn.Linear(l2, 10),
+        )
+
+    def forward(self, x):
+        x = self.conv_relu_stack(x)
+        x = x.view(x.size(0), 16 * 5 * 5)  # flatten conv features per sample
+        return self.linear_relu_stack(x)
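+
+
+# Shape sanity check (a sketch, not executed here):
+#     net = NeuralNetwork()
+#     logits = net(torch.randn(4, 3, 32, 32))  # -> torch.Size([4, 10])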
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..451e3912345162eaa8641cf2e34957881afb6350
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,44 @@
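+"""CIFAR-10 download helpers and train/validation DataLoader construction."""
+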
+from torch.utils.data import random_split, DataLoader
+from torchvision import datasets
+from torchvision.transforms import ToTensor
+
+
+def get_data(data_root, download=False):
+    transform = ToTensor()
+    # Download training data from open datasets.
+    training_data = datasets.CIFAR10(
+        root=data_root,
+        train=True,
+        download=download,
+        transform=transform,
+    )
+
+    # Download test data from open datasets.
+    testing_data = datasets.CIFAR10(
+        root=data_root,
+        train=False,
+        download=download,
+        transform=transform,
+    )
+
+    return training_data, testing_data
+
+
+def load_data(config, data_root):
+    """Split the CIFAR-10 training set 80/20 into train/validation loaders."""
+    train_set, _ = get_data(data_root)
+
+    # Hold out 20% of the training set for validation during tuning.
+    train_size = int(len(train_set) * 0.8)
+    train_subset, val_subset = random_split(
+        train_set, [train_size, len(train_set) - train_size])
+
+    train_loader = DataLoader(
+        train_subset,
+        batch_size=int(config["batch_size"]),
+        shuffle=True,
+        num_workers=2)
+    val_loader = DataLoader(
+        val_subset,
+        batch_size=int(config["batch_size"]),
+        shuffle=False,  # no need to shuffle for evaluation
+        num_workers=2)
+    return train_loader, val_loader
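+
+
+# Example use inside a tuning trial (the path shown is illustrative):
+#     train_loader, val_loader = load_data({"batch_size": 8}, "./data")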
diff --git a/main.py b/main.py
index cc9f07327c1df504746f7293480fe22181db9343..38d807ecb6d43c9b8f4a36f94c3c26fe29d7b4a5 100644
--- a/main.py
+++ b/main.py
@@ -1,126 +1,71 @@
-from os.path import isfile
-
-import torch
-from torch import nn
-from torch.utils.data import DataLoader
-from torchvision import datasets
-from torchvision.transforms import ToTensor
-
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
+from functools import partial
+from os.path import join
+
+from numpy.random import randint
+from ray import tune
+from ray.tune import CLIReporter
+from ray.tune.schedulers import ASHAScheduler
+from torch import nn, load, save
+from torch.cuda import is_available
+
+from NeuralNetwork import NeuralNetwork
+from dataset import get_data
+from tests import test_accuracy
+from training import training
+
+device = "cuda:0" if is_available() else "cpu"
 print(f"Using {device} device")
 
 
-def get_data(batch_size: int = 64):
-    # Download training data from open datasets.
-    training_data = datasets.CIFAR10(
-        root="/home/flifloo/IA/data",
-        train=True,
-        download=True,
-        transform=ToTensor(),
-    )
-
-    # Download test data from open datasets.
-    testing_data = datasets.CIFAR10(
-        root="/home/flifloo/IA/data",
-        train=False,
-        download=True,
-        transform=ToTensor(),
-    )
-
-    # Create data loaders.
-    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
-    test_dataloader = DataLoader(testing_data, batch_size=batch_size, shuffle=True)
-
-    return train_dataloader, test_dataloader
-
-
-# Define model
-class NeuralNetwork(nn.Module):
-    def __init__(self):
-        super(NeuralNetwork, self).__init__()
-        self.conv_relu_stack = nn.Sequential(
-            nn.Conv2d(3, 6, (5, 5)),
-            nn.MaxPool2d(2, 2),
-            nn.ReLU(),
-            nn.Conv2d(6, 16, (5, 5)),
-            nn.MaxPool2d(2, 2),
-            nn.ReLU(),
-        )
-        self.linear_relu_stack = nn.Sequential(
-            nn.Linear(16*(5**2), 120),
-            nn.ReLU(),
-            nn.Linear(120, 84),
-            nn.ReLU(),
-            nn.Linear(84, 10),
-            nn.ReLU(),
-        )
-
-    def forward(self, x):
-        x = self.conv_relu_stack(x)
-        x = x.view(-1, 16 * 5 * 5)
-        return self.linear_relu_stack(x)
-
-
-def train(dataloader, model, loss_fn, optimizer):
-    size = len(dataloader.dataset)
-    for batch, (X, y) in enumerate(dataloader):
-        X, y = X.to(device), y.to(device)
-
-        # Compute prediction error
-        pred = model(X)
-        loss = loss_fn(pred, y)
-
-        # Backpropagation
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-
-        if batch % 100 == 0:
-            loss, current = loss.item(), batch * len(X)
-            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
-
-
-def test(dataloader, model, loss_fn):
-    size = len(dataloader.dataset)
-    model.eval()
-    test_loss, correct = 0, 0
-    with torch.no_grad():
-        for X, y in dataloader:
-            X, y = X.to(device), y.to(device)
-            pred = model(X)
-            test_loss += loss_fn(pred, y).item()
-            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
-    test_loss /= size
-    correct /= size
-    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
-    return correct
-
-
-def training():
-    train_data, test_data = get_data()
-
-    model = NeuralNetwork().to(device)
-    if isfile("model.pth"):
-        print("Loading model from save")
-        model.load_state_dict(torch.load("model.pth"))
-
-    print(model)
-
-    loss_fn = nn.CrossEntropyLoss()
-    # lr = sur/sous appretisage
-    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=.9)
-
-    e = 0
-    c = 0
-    while c < 0.90:
-        print(f"Epoch {e+1}\n-------------------------------")
-        train(train_data, model, loss_fn, optimizer)
-        c = test(test_data, model, loss_fn)
-        torch.save(model.state_dict(), "model.pth")
-        e += 1
-    print("Done!")
-
-
-if __name__ == '__main__':
-    training()
+def main(data_root, num_samples=10, max_num_epochs=10, gpus_per_trial=1):
+    # Download CIFAR-10 once up front so concurrent trial workers don't race.
+    get_data(data_root, download=True)
+
+    # Search space: layer widths are powers of two in [4, 256], the learning
+    # rate is sampled log-uniformly, and batch sizes are small powers of two.
+    config = {
+        "l1": tune.sample_from(lambda _: 2 ** randint(2, 9)),
+        "l2": tune.sample_from(lambda _: 2 ** randint(2, 9)),
+        "lr": tune.loguniform(1e-4, 1e-1),
+        "batch_size": tune.choice([2, 4, 8, 16]),
+    }
+    # ASHA early-stops the worst-performing trials based on reported loss.
+    scheduler = ASHAScheduler(
+        metric="loss",
+        mode="min",
+        max_t=max_num_epochs,
+        grace_period=1,
+        reduction_factor=2)
+    reporter = CLIReporter(
+        # parameter_columns=["l1", "l2", "lr", "batch_size"],
+        metric_columns=["loss", "accuracy", "training_iteration"])
+    result = tune.run(
+        partial(training, data_root=data_root, device=device),
+        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
+        config=config,
+        num_samples=num_samples,
+        scheduler=scheduler,
+        progress_reporter=reporter)
+
+    best_trial = result.get_best_trial("loss", "min", "last")
+    print(f"Best trial config: {best_trial.config}")
+    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
+    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")
+
+    best_trained_model = NeuralNetwork(best_trial.config["l1"], best_trial.config["l2"])
+    if is_available() and gpus_per_trial > 1:
+        best_trained_model = nn.DataParallel(best_trained_model)
+    best_trained_model.to(device)
+
+    # Restore the weights from the best trial's final checkpoint.
+    best_checkpoint_dir = best_trial.checkpoint.value
+    model_state, optimizer_state = load(
+        join(best_checkpoint_dir, "checkpoint"))
+    best_trained_model.load_state_dict(model_state)
+
+    # Also save the best weights locally, in case the Tune checkpoint
+    # directory is cleaned up later.
+    print("Saving best model...")
+    save(best_trained_model.state_dict(), "/home/flifloo/IA/model.pth")
+
+    print("Testing accuracy...")
+    print(f"Best trial test set accuracy: {test_accuracy(best_trained_model, data_root, device)}")
+
+
+if __name__ == "__main__":
+    main("/home/flifloo/IA/data")
diff --git a/tests.py b/tests.py
new file mode 100644
index 0000000000000000000000000000000000000000..94b328d1b1f67fdde0f57f9842505a5c7a3f8b6b
--- /dev/null
+++ b/tests.py
@@ -0,0 +1,45 @@
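+"""Evaluation helpers: batched loss/accuracy and held-out test-set accuracy."""
+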
+from torch import no_grad
+from torch.utils.data import DataLoader
+
+from dataset import get_data
+
+
+def test(test_loader, net, criterion, device):
+    """Return (average loss, accuracy) of net over test_loader."""
+    val_loss = 0.0
+    val_steps = 0
+    total = 0
+    correct = 0
+    net.eval()
+    with no_grad():
+        for inputs, labels in test_loader:
+            inputs, labels = inputs.to(device), labels.to(device)
+
+            outputs = net(inputs)
+            predicted = outputs.argmax(1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+            val_loss += criterion(outputs, labels).item()
+            val_steps += 1
+    return val_loss / val_steps, correct / total
+
+
+def test_accuracy(net, data_root, device):
+    """Return the accuracy of net on the held-out CIFAR-10 test set."""
+    _, test_set = get_data(data_root)
+
+    test_loader = DataLoader(
+        test_set, batch_size=4, shuffle=False, num_workers=2)
+
+    correct = 0
+    total = 0
+    net.eval()
+    with no_grad():
+        for images, labels in test_loader:
+            images, labels = images.to(device), labels.to(device)
+            outputs = net(images)
+            predicted = outputs.argmax(1)
+            total += labels.size(0)
+            correct += (predicted == labels).sum().item()
+
+    return correct / total
diff --git a/training.py b/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5b051486d6861b7356a0959f5f94e8b6526452c
--- /dev/null
+++ b/training.py
@@ -0,0 +1,60 @@
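+"""Ray Tune trainable: trains the CNN, checkpointing and reporting each epoch."""
+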
+from os.path import join
+
+from ray import tune
+from torch import save, load, nn
+from torch.optim import SGD
+
+from NeuralNetwork import NeuralNetwork
+from dataset import load_data
+from tests import test
+
+
+def train(train_loader, net, optimizer, criterion, epoch, device):
+    net.train()  # counterpart to net.eval() in tests.test
+    running_loss = 0.0
+    epoch_steps = 0
+    for i, data in enumerate(train_loader):
+        # get the inputs; data is a list of [inputs, labels]
+        inputs, labels = data
+        inputs, labels = inputs.to(device), labels.to(device)
+
+        # zero the parameter gradients
+        optimizer.zero_grad()
+
+        # forward + backward + optimize
+        outputs = net(inputs)
+        loss = criterion(outputs, labels)
+        loss.backward()
+        optimizer.step()
+
+        # print statistics
+        running_loss += loss.item()
+        epoch_steps += 1
+        if i % 2000 == 1999:  # print every 2000 mini-batches
+            print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
+                                            running_loss / epoch_steps))
+            running_loss = 0.0
+            epoch_steps = 0  # reset with running_loss so the average stays windowed
+
+
+def training(config, data_root, device="cpu", checkpoint_dir=None):
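+    """Trainable for tune.run: Ray injects config (and checkpoint_dir on resume)."""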
+    net = NeuralNetwork(config["l1"], config["l2"]).to(device)
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = SGD(net.parameters(), lr=config["lr"], momentum=0.9)
+
+    if checkpoint_dir:
+        model_state, optimizer_state = load(
+            join(checkpoint_dir, "checkpoint"))
+        net.load_state_dict(model_state)
+        optimizer.load_state_dict(optimizer_state)
+
+    train_loader, val_loader = load_data(config, data_root)
+
+    for epoch in range(10):  # ASHA may stop a trial early via max_t
+        train(train_loader, net, optimizer, criterion, epoch, device)
+        loss, accuracy = test(val_loader, net, criterion, device)
+
+        # Checkpoint weights and optimizer state, then report metrics to Tune.
+        with tune.checkpoint_dir(epoch) as checkpoint_dir:
+            path = join(checkpoint_dir, "checkpoint")
+            save((net.state_dict(), optimizer.state_dict()), path)
+        tune.report(loss=loss, accuracy=accuracy)
+    print("Finished Training")