from os.path import join

from ray import tune
from torch import save, load, nn
from torch.optim import SGD

from NeuralNetwork import NeuralNetwork
from dataset import load_data
from tests import test


def train(train_loader, net, optimizer, criterion, epoch, device):
    running_loss = 0.0
    epoch_steps = 0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        epoch_steps += 1
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print("[%d, %5d] loss: %.3f" %
                  (epoch + 1, i + 1, running_loss / epoch_steps))
            # reset both counters so the printed value is the mean loss over
            # the last 2000 mini-batches, not a steadily shrinking average
            running_loss = 0.0
            epoch_steps = 0


def training(config, data_root, device="cpu", checkpoint_dir=None):
    net = NeuralNetwork(config["l1"], config["l2"]).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # resume from a Ray Tune checkpoint if one was handed to this trial
    if checkpoint_dir:
        model_state, optimizer_state = load(join(checkpoint_dir, "checkpoint"))
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    train_loader, test_loader = load_data(config, data_root)

    for epoch in range(10):
        train(train_loader, net, optimizer, criterion, epoch, device)
        loss, accuracy = test(test_loader, net, criterion, device)

        # checkpoint model and optimizer state after every epoch, then
        # report the validation metrics back to Ray Tune
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = join(checkpoint_dir, "checkpoint")
            save((net.state_dict(), optimizer.state_dict()), path)
        tune.report(loss=loss, accuracy=accuracy)

    print("Finished Training")
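

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): how a
# function-style trainable like `training` is typically launched with the
# legacy tune.run API that matches the tune.checkpoint_dir / tune.report
# calls above. The search-space values, num_samples, and the "./data" path
# are assumptions, not taken from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from functools import partial

    config = {
        "l1": tune.choice([64, 128, 256]),   # assumed layer-size candidates
        "l2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),   # assumed learning-rate range
    }
    analysis = tune.run(
        partial(training, data_root="./data"),  # placeholder data location
        config=config,
        num_samples=10,
    )
    print("Best config:", analysis.get_best_config(metric="loss", mode="min"))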