diff --git a/main.py b/main.py index 38d807ecb6d43c9b8f4a36f94c3c26fe29d7b4a5..ce3a9f8b9716e4fb4c0fd451b9b4a6f025aeaf89 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,5 @@ from functools import partial -from os.path import join +from os.path import join, abspath from numpy.random import randint from ray import tune @@ -59,13 +59,9 @@ def main(data_root, num_samples=10, max_num_epochs=10, gpus_per_trial=1): best_checkpoint_dir, "checkpoint")) best_trained_model.load_state_dict(model_state) - # If Pytorch don't save the end - print("In case saving...") - save(best_trained_model, "/home/flifloo/IA/model.pth") - print("Testing accuracy...") print(f"Best trial test set accuracy: {test_accuracy(best_trained_model, data_root, device)}") if __name__ == "__main__": - main("/home/flifloo/IA/data") + main(abspath("data"))