"""Hyperparameter search for ``NeuralNetwork`` using Ray Tune.

Runs an ASHA-scheduled search over the two layer widths, the learning rate
and the batch size, then reloads the best trial's checkpoint and prints its
accuracy as computed by ``tests.test_accuracy``.
"""
from functools import partial
from os.path import join, abspath

from numpy.random import randint
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from torch import nn, load, save
from torch.cuda import is_available

from NeuralNetwork import NeuralNetwork
from dataset import get_data
from tests import test_accuracy
from training import training

device = "cuda:0" if is_available() else "cpu"
print(f"Using {device} device")


def main(
    data_root: str,
    num_samples: int = 10,
    max_num_epochs: int = 10,
    gpus_per_trial: float = 1,
) -> None:
    """Run the hyperparameter search and evaluate the best trial.

    Args:
        data_root: Directory handed to ``get_data``, ``training`` and
            ``test_accuracy`` as the dataset location.
        num_samples: Number of configurations Ray Tune samples from the
            search space.
        max_num_epochs: Upper bound on training iterations per trial
            (``max_t`` of the ASHA scheduler).
        gpus_per_trial: GPUs requested for each trial. Ignored (treated as
            0) when CUDA is not available in this process, because trials
            are then told to run on ``"cpu"`` anyway.

    Raises:
        RuntimeError: If no trial reported a ``loss`` value, so no best
            trial can be selected.
    """
    # NOTE(review): the second argument is presumably a "download" flag that
    # fetches the data once up front -- confirm against dataset.get_data.
    get_data(data_root, True)

    config = {
        # numpy's randint excludes the upper bound: exponents 2..8 -> 4..256.
        "l1": tune.sample_from(lambda _: 2 ** randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
    }
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )
    reporter = CLIReporter(
        # parameter_columns=["l1", "l2", "lr", "batch_size"],
        metric_columns=["loss", "accuracy", "training_iteration"],
    )

    # Only ask Ray for GPUs when this process can actually see one. The
    # module-level ``device`` has already fallen back to "cpu" otherwise, and
    # a GPU request on a host without GPUs leaves Ray unable to place trials.
    gpus_requested = gpus_per_trial if is_available() else 0

    result = tune.run(
        partial(training, data_root=data_root, device=device),
        resources_per_trial={"cpu": 2, "gpu": gpus_requested},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
    )

    best_trial = result.get_best_trial("loss", "min", "last")
    if best_trial is None:
        # Fail with a clear message instead of an AttributeError below.
        raise RuntimeError(
            "No trial reported a 'loss' metric; cannot select a best trial."
        )

    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

    # Rebuild the network with the winning layer widths before loading weights.
    best_trained_model = NeuralNetwork(
        best_trial.config["l1"], best_trial.config["l2"]
    )
    if gpus_requested > 1:
        best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    # The file named "checkpoint" holds a (model_state, optimizer_state) pair;
    # only the model weights are needed for evaluation. ``map_location``
    # places the tensors on the evaluation device regardless of where the
    # trial saved them from.
    best_checkpoint_dir = best_trial.checkpoint.value
    model_state, _optimizer_state = load(
        join(best_checkpoint_dir, "checkpoint"), map_location=device
    )
    best_trained_model.load_state_dict(model_state)

    print("Testing accuracy...")
    accuracy = test_accuracy(best_trained_model, data_root, device)
    print(f"Best trial test set accuracy: {accuracy}")


if __name__ == "__main__":
    main(abspath("data"))