"""Ray Tune training driver.

Defines a single-epoch SGD training loop (``train``) and a Tune-compatible
trainable (``training``) that restores/saves checkpoints and reports
validation loss/accuracy back to the scheduler each epoch.
"""
from os.path import join

from ray import tune
from torch import load, nn, save
from torch.optim import SGD

from NeuralNetwork import NeuralNetwork
from dataset import load_data
from tests import test


def train(train_loader, net, optimizer, criterion, epoch, device):
    """Run one training epoch of ``net`` over ``train_loader``.

    Args:
        train_loader: iterable yielding ``(inputs, labels)`` mini-batches.
        net: model being optimized; assumed already moved to ``device``.
        optimizer: optimizer stepping ``net``'s parameters.
        criterion: loss function applied as ``criterion(outputs, labels)``.
        epoch: zero-based epoch index, used only in log messages.
        device: torch device the batch tensors are moved to.
    """
    running_loss = 0.0
    epoch_steps = 0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        epoch_steps += 1
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print("[%d, %5d] loss: %.3f"
                  % (epoch + 1, i + 1, running_loss / epoch_steps))
            running_loss = 0.0


def training(config, data_root, device="cpu", checkpoint_dir=None):
    """Tune trainable: build the model from ``config``, train 10 epochs.

    Args:
        config: hyperparameter dict; uses keys ``"l1"``, ``"l2"``, ``"lr"``,
            and whatever ``load_data`` reads (e.g. batch size).
        data_root: directory passed to ``load_data`` for the dataset.
        device: torch device string; defaults to ``"cpu"``.
        checkpoint_dir: if given, directory containing a ``checkpoint`` file
            (a ``(model_state, optimizer_state)`` tuple) to resume from.
    """
    net = NeuralNetwork(config["l1"], config["l2"]).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # Restore model/optimizer state when Tune resumes this trial.
    if checkpoint_dir:
        model_state, optimizer_state = load(join(checkpoint_dir, "checkpoint"))
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    train_loader, test_loader = load_data(config, data_root)

    for epoch in range(10):
        train(train_loader, net, optimizer, criterion, epoch, device)
        loss, accuracy = test(test_loader, net, criterion, device)

        # NOTE: use a distinct name (ckpt_dir) so the function parameter
        # ``checkpoint_dir`` is not shadowed by the context variable.
        with tune.checkpoint_dir(epoch) as ckpt_dir:
            path = join(ckpt_dir, "checkpoint")
            save((net.state_dict(), optimizer.state_dict()), path)

        # Report epoch metrics so the Tune scheduler can compare trials.
        tune.report(loss=loss, accuracy=accuracy)

    print("Finished Training")