45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
from torch import no_grad, max
|
|
from torch.utils.data import DataLoader
|
|
|
|
from dataset import get_data
|
|
|
|
|
|
def test(test_loader, net, criterion, device):
|
|
val_loss = 0.0
|
|
val_steps = 0
|
|
total = 0
|
|
correct = 0
|
|
for i, data in enumerate(test_loader, 0):
|
|
with no_grad():
|
|
inputs, labels = data
|
|
inputs, labels = inputs.to(device), labels.to(device)
|
|
|
|
outputs = net(inputs)
|
|
_, predicted = max(outputs.data, 1)
|
|
total += labels.size(0)
|
|
correct += (predicted == labels).sum().item()
|
|
|
|
loss = criterion(outputs, labels)
|
|
val_loss += loss.cpu().numpy()
|
|
val_steps += 1
|
|
return val_loss / val_steps, correct / total
|
|
|
|
|
|
def test_accuracy(net, data_root, device, *, batch_size=4, num_workers=2):
    """Return the classification accuracy of ``net`` on the test split.

    Args:
        net: Callable model mapping a batch of images to class scores; the
            predicted class is the argmax over dimension 1.
        data_root: Passed through to ``get_data`` to locate the dataset.
        device: Device that images and labels are moved to.
        batch_size: Batch size of the test ``DataLoader`` (default 4, the
            previously hard-coded value).
        num_workers: Number of ``DataLoader`` worker processes (default 2,
            the previously hard-coded value).

    Returns:
        Fraction of correctly classified test samples, or ``float("nan")``
        if the test set yields no samples.

    Note:
        The train/eval mode of ``net`` is not changed here. Call
        ``net.eval()`` beforehand if the model has dropout or batch-norm
        layers.
    """
    # get_data returns (train_set, test_set); only the test split is used.
    _, test_set = get_data(data_root)

    test_loader = DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    correct = 0
    total = 0
    with no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty test set instead of raising ZeroDivisionError.
    return correct / total if total else float("nan")
|