"""Hyper-parameter search for ``NeuralNetwork`` using Ray Tune.

Running this module as a script launches a Ray Tune search over the layer
sizes, learning rate and batch size, restores the best trial's checkpoint,
saves the resulting model and prints its accuracy on the test set.
"""
from functools import partial
from os.path import join

from numpy.random import randint
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from torch import nn, load, save
from torch.cuda import is_available

from NeuralNetwork import NeuralNetwork
from dataset import get_data
from tests import test_accuracy
from training import training

# Use the first CUDA device when one is visible, otherwise fall back to CPU.
device = "cuda:0" if is_available() else "cpu"
print(f"Using {device} device")


def main(data_root, num_samples=10, max_num_epochs=10, gpus_per_trial=1,
         model_path="/home/flifloo/IA/model.pth"):
    """Run the hyper-parameter search and evaluate the best model found.

    Args:
        data_root: Directory handed to ``get_data``, ``training`` and
            ``test_accuracy`` as the dataset location.
        num_samples: Number of configurations Ray Tune samples from the
            search space.
        max_num_epochs: Upper bound on training iterations per trial
            (``max_t`` of the ASHA scheduler).
        gpus_per_trial: GPUs reserved for each trial. Ignored (treated as 0)
            when CUDA is not available, because a GPU request could not be
            satisfied on a CPU-only host.
        model_path: File the best model is written to with ``torch.save``.

    Returns:
        None. Results are printed and the best model is written to
        ``model_path``.
    """
    # NOTE(review): the second argument presumably asks get_data to fetch the
    # dataset up front so the trials do not each do it -- confirm in dataset.py.
    get_data(data_root, True)

    # Search space: layer widths are powers of two from 2**2 to 2**8
    # (randint's upper bound is exclusive), lr is log-uniform in [1e-4, 1e-1].
    config = {
        "l1": tune.sample_from(lambda _: 2 ** randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
    }

    # ASHA stops under-performing trials early, judged on the reported "loss".
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )
    reporter = CLIReporter(
        # parameter_columns=["l1", "l2", "lr", "batch_size"],
        metric_columns=["loss", "accuracy", "training_iteration"],
    )

    # Only reserve GPUs when this process actually selected a CUDA device;
    # otherwise trials would be trained on "cpu" while still waiting for a GPU
    # that does not exist.
    gpus_requested = gpus_per_trial if is_available() else 0

    result = tune.run(
        partial(training, data_root=data_root, device=device),
        resources_per_trial={"cpu": 2, "gpu": gpus_requested},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
    )

    best_trial = result.get_best_trial("loss", "min", "last")
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

    # Rebuild the network with the winning layer sizes before loading weights.
    best_trained_model = NeuralNetwork(best_trial.config["l1"], best_trial.config["l2"])
    if is_available() and gpus_per_trial > 1:
        best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    # The checkpoint file holds a (model_state, optimizer_state) pair; only the
    # model state is needed here. map_location keeps the tensors loadable on
    # the device selected above, whatever device they were saved from.
    best_checkpoint_dir = best_trial.checkpoint.value
    model_state, optimizer_state = load(
        join(best_checkpoint_dir, "checkpoint"), map_location=device)
    best_trained_model.load_state_dict(model_state)

    # Save the model before testing, so it is kept even if the run does not
    # reach its end.
    print("In case saving...")
    save(best_trained_model, model_path)

    print("Testing accuracy...")
    print(f"Best trial test set accuracy: {test_accuracy(best_trained_model, data_root, device)}")


if __name__ == "__main__":
    main("/home/flifloo/IA/data")