67 lines
2.3 KiB
Python
67 lines
2.3 KiB
Python
from functools import partial
|
|
from os.path import join, abspath
|
|
|
|
from numpy.random import randint
|
|
from ray import tune
|
|
from ray.tune import CLIReporter
|
|
from ray.tune.schedulers import ASHAScheduler
|
|
from torch import nn, load, save
|
|
from torch.cuda import is_available
|
|
|
|
from NeuralNetwork import NeuralNetwork
|
|
from dataset import get_data
|
|
from tests import test_accuracy
|
|
from training import training
|
|
|
|
# Pick the compute target once, at import time: the first CUDA GPU when
# torch reports one as available, otherwise the CPU.
if is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(f"Using {device} device")
|
|
|
|
|
|
def main(data_root, num_samples=10, max_num_epochs=10, gpus_per_trial=1):
    """Run a Ray Tune hyperparameter search and evaluate the best trial.

    The search samples two layer sizes (``l1``, ``l2``), a learning rate and
    a batch size, schedules trials with ASHA on the reported ``loss``, prints
    the best trial's config and last reported loss/accuracy, then restores
    that trial's checkpoint into a fresh ``NeuralNetwork`` and prints the
    value returned by ``test_accuracy``.

    Args:
        data_root: Directory handed to ``get_data``, ``training`` and
            ``test_accuracy``.
        num_samples: Number of configurations ``tune.run`` samples.
        max_num_epochs: Upper bound (``max_t``) given to the ASHA scheduler.
        gpus_per_trial: GPUs requested from Ray for each trial. Only honoured
            when CUDA is available; otherwise no GPU is requested, because
            the trials are run on the CPU ``device`` anyway.

    Raises:
        RuntimeError: If Ray Tune cannot determine a best trial (no trial
            reported a ``loss``).
    """
    # NOTE(review): the second argument presumably asks get_data to
    # download/prepare the dataset before any trial starts -- confirm
    # against dataset.get_data.
    get_data(data_root, True)

    config = {
        # Layer widths are powers of two: 2**2 .. 2**8 (randint's upper
        # bound is exclusive).
        "l1": tune.sample_from(lambda _: 2 ** randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
    }
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )
    reporter = CLIReporter(
        metric_columns=["loss", "accuracy", "training_iteration"],
    )

    # Trials train on the module-level ``device``. When that is the CPU,
    # requesting GPUs from Ray would only make the trials unschedulable on a
    # GPU-less host, so the GPU request is dropped in that case.
    trial_gpus = gpus_per_trial if is_available() else 0

    result = tune.run(
        partial(training, data_root=data_root, device=device),
        resources_per_trial={"cpu": 2, "gpu": trial_gpus},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
    )

    best_trial = result.get_best_trial("loss", "min", "last")
    if best_trial is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError(
            "Ray Tune could not determine a best trial; "
            "no trial reported a 'loss' metric."
        )
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

    best_trained_model = NeuralNetwork(best_trial.config["l1"], best_trial.config["l2"])
    if is_available() and gpus_per_trial > 1:
        # NOTE(review): wrapping in DataParallel changes the state-dict key
        # names; this assumes ``training`` saved its checkpoint from a model
        # wrapped the same way -- confirm against training.py.
        best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

    best_checkpoint_dir = best_trial.checkpoint.value
    # The checkpoint holds a (model_state, optimizer_state) pair; only the
    # model weights are needed for evaluation. map_location places the
    # restored tensors on the evaluation device regardless of where they
    # were saved from.
    model_state, _ = load(
        join(best_checkpoint_dir, "checkpoint"),
        map_location=device,
    )
    best_trained_model.load_state_dict(model_state)

    print("Testing accuracy...")
    print(f"Best trial test set accuracy: {test_accuracy(best_trained_model, data_root, device)}")
|
|
|
|
|
|
# Script entry point: run the search against the "data" directory resolved
# relative to the current working directory, using main()'s defaults.
if __name__ == "__main__":
    main(data_root=abspath("data"))
|