1
0
Fork 0
This repository has been archived on 2024-02-17. You can view files and clone it, but cannot push or open issues or pull requests.
TP_IA/training.py

61 lines
2 KiB
Python
Raw Normal View History

from os.path import join
from ray import tune
from torch import save, load, nn
from torch.optim import SGD
from NeuralNetwork import NeuralNetwork
from dataset import load_data
from tests import test
def train(train_loader, net, optimizer, criterion, epoch, device):
    """Run one training epoch of `net` over `train_loader`.

    Args:
        train_loader: iterable yielding (inputs, labels) mini-batches.
        net: model to train (assumed already moved to `device`).
        optimizer: optimizer stepping `net`'s parameters.
        criterion: loss function, e.g. nn.CrossEntropyLoss.
        epoch: zero-based epoch index (used only in the log message).
        device: torch device the batches are moved to.
    """
    running_loss = 0.0
    epoch_steps = 0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        epoch_steps += 1
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                            running_loss / epoch_steps))
            # BUG FIX: epoch_steps must be reset together with running_loss;
            # otherwise every window after the first divides a 2000-batch
            # loss sum by an ever-growing step count, under-reporting loss.
            running_loss = 0.0
            epoch_steps = 0
def training(config, data_root, device="cpu", checkpoint_dir=None):
    """Ray Tune trainable: train a NeuralNetwork and report metrics each epoch.

    Args:
        config: hyperparameter dict; this function reads "l1", "l2" (layer
            sizes) and "lr" — load_data may read further keys from it.
        data_root: dataset location forwarded to load_data.
        device: torch device to train on; defaults to "cpu".
        checkpoint_dir: optional Ray Tune checkpoint directory to resume from.
    """
    net = NeuralNetwork(config["l1"], config["l2"]).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    if checkpoint_dir:
        # FIX: map_location is required so a checkpoint written on one
        # device (e.g. a GPU trial) can be restored on another (e.g. a
        # CPU-only worker); without it torch.load raises on deserialization.
        model_state, optimizer_state = load(
            join(checkpoint_dir, "checkpoint"), map_location=device)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    train_loader, test_loader = load_data(config, data_root)

    for epoch in range(10):
        train(train_loader, net, optimizer, criterion, epoch, device)
        loss, accuracy = test(test_loader, net, criterion, device)

        # NOTE(review): tune.checkpoint_dir / tune.report(**kwargs) are the
        # legacy (pre-2.x) Ray Tune function API — kept as-is for
        # compatibility with the Ray version this project pins.
        with tune.checkpoint_dir(epoch) as ckpt_dir:
            path = join(ckpt_dir, "checkpoint")
            save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=loss, accuracy=accuracy)

    print("Finished Training")