Change data path to a relative path
This commit is contained in:
parent
95c2b25268
commit
3e2c776b3e
1 changed files with 2 additions and 6 deletions
8
main.py
8
main.py
|
@ -1,5 +1,5 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from os.path import join
|
from os.path import join, abspath
|
||||||
|
|
||||||
from numpy.random import randint
|
from numpy.random import randint
|
||||||
from ray import tune
|
from ray import tune
|
||||||
|
@ -59,13 +59,9 @@ def main(data_root, num_samples=10, max_num_epochs=10, gpus_per_trial=1):
|
||||||
best_checkpoint_dir, "checkpoint"))
|
best_checkpoint_dir, "checkpoint"))
|
||||||
best_trained_model.load_state_dict(model_state)
|
best_trained_model.load_state_dict(model_state)
|
||||||
|
|
||||||
# If Pytorch don't save the end
|
|
||||||
print("In case saving...")
|
|
||||||
save(best_trained_model, "/home/flifloo/IA/model.pth")
|
|
||||||
|
|
||||||
print("Testing accuracy...")
|
print("Testing accuracy...")
|
||||||
print(f"Best trial test set accuracy: {test_accuracy(best_trained_model, data_root, device)}")
|
print(f"Best trial test set accuracy: {test_accuracy(best_trained_model, data_root, device)}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main("/home/flifloo/IA/data")
|
main(abspath("data"))
|
||||||
|
|
Reference in a new issue