import torch
from torch.utils.data import DataLoader

from dataset import get_data


def test(test_loader, net, criterion, device):
    """Evaluate `net` on `test_loader`, returning (mean loss, accuracy)."""
    net.eval()  # disable dropout / use running batch-norm statistics
    val_loss = 0.0
    val_steps = 0
    total = 0
    correct = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for data in test_loader:
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = net(inputs)
            # The predicted class is the index of the largest logit.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loss = criterion(outputs, labels)
            val_loss += loss.item()
            val_steps += 1

    return val_loss / val_steps, correct / total


def test_accuracy(net, data_root, device):
    """Compute the accuracy of `net` on the test split under `data_root`."""
    _, test_set = get_data(data_root)  # training split is unused here

    test_loader = DataLoader(
        test_set, batch_size=4, shuffle=False, num_workers=2)

    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total
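

if __name__ == "__main__":
    # Usage sketch: evaluate a trained checkpoint on the test split.
    # `Net` (imported from a model.py), "checkpoint.pt", and "./data" are
    # assumed names for illustration only; they are not defined in this
    # module and may differ in the actual project.
    from model import Net  # hypothetical model class

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    net = Net().to(device)
    net.load_state_dict(torch.load("checkpoint.pt", map_location=device))

    accuracy = test_accuracy(net, data_root="./data", device=device)
    print(f"Test accuracy: {accuracy:.4f}")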