1
0
Fork 0
This repository has been archived on 2024-02-17. You can view files and clone it, but cannot push or open issues or pull requests.
TP_IA/main.py

68 lines
2.3 KiB
Python

from functools import partial
from os.path import join, abspath
from numpy.random import randint
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from torch import nn, load, save
from torch.cuda import is_available
from NeuralNetwork import NeuralNetwork
from dataset import get_data
from tests import test_accuracy
from training import training
# Torch device string used both for the trials (via main) and for the final
# evaluation: the first CUDA device when one is visible, otherwise the CPU.
if is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(f"Using {device} device")
def _restore_best_model(best_trial, gpus_per_trial):
    """Rebuild the network for ``best_trial`` and load its saved weights.

    The layer sizes come from the trial's ``l1``/``l2`` config entries. The
    checkpoint file ``<checkpoint dir>/checkpoint`` is unpacked as a
    ``(model_state, optimizer_state)`` pair; only the model state is used.
    """
    model = NeuralNetwork(best_trial.config["l1"], best_trial.config["l2"])
    if is_available() and gpus_per_trial > 1:
        # Presumably multi-GPU trials trained a DataParallel-wrapped model, so
        # wrap here as well for the state-dict keys to match — TODO confirm
        # against training.py.
        model = nn.DataParallel(model)
    # No-op on CPU; moves the parameters to the GPU otherwise.
    model.to(device)
    checkpoint_path = join(best_trial.checkpoint.value, "checkpoint")
    # map_location remaps tensors saved on another device onto this
    # process's device, so the load does not depend on where a trial ran.
    model_state, _optimizer_state = load(checkpoint_path, map_location=device)
    model.load_state_dict(model_state)
    return model


def main(data_root, num_samples=10, max_num_epochs=10, gpus_per_trial=1):
    """Search hyperparameters with Ray Tune, then test the best trial's model.

    Args:
        data_root: Dataset directory, forwarded to ``get_data``, ``training``
            and ``test_accuracy``.
        num_samples: Number of hyperparameter configurations to sample.
        max_num_epochs: Passed to ASHA as ``max_t``, the maximum number of
            training iterations a trial may run (one per epoch, assuming
            ``training`` reports once per epoch).
        gpus_per_trial: GPUs requested for each trial. Forced to 0 when CUDA
            is not available, because such a request could never be met and
            the trials would never be scheduled.

    Raises:
        RuntimeError: If no trial reported a ``loss`` metric.
    """
    # The meaning of the second argument is defined in dataset.get_data
    # (presumably a "download" flag — TODO confirm).
    get_data(data_root, True)

    if not is_available():
        # Keep the resource request consistent with the module-level CPU
        # fallback for `device`.
        gpus_per_trial = 0

    # numpy's randint excludes its upper bound, so hidden-layer sizes are
    # powers of two from 2**2 = 4 to 2**8 = 256.
    config = {
        "l1": tune.sample_from(lambda _: 2 ** randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
    }
    # ASHA stops unpromising trials early, judging them by the lowest loss.
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )
    reporter = CLIReporter(
        metric_columns=["loss", "accuracy", "training_iteration"])
    result = tune.run(
        partial(training, data_root=data_root, device=device),
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
    )

    # Rank trials by the loss each one reported last.
    best_trial = result.get_best_trial("loss", "min", "last")
    if best_trial is None:
        raise RuntimeError("No trial reported a 'loss' metric; nothing to test.")
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

    best_trained_model = _restore_best_model(best_trial, gpus_per_trial)
    print("Testing accuracy...")
    print(f"Best trial test set accuracy: {test_accuracy(best_trained_model, data_root, device)}")
if __name__ == "__main__":
    # Pass an absolute path, presumably because Ray Tune trials do not run
    # in the launch directory, so a relative "data" path could not be resolved
    # there — TODO confirm for the Ray version in use.
    main(data_root=abspath("data"))