45 lines
1.1 KiB
Python
45 lines
1.1 KiB
Python
|
from torch.utils.data import random_split, DataLoader
|
||
|
from torchvision import datasets
|
||
|
from torchvision.transforms import ToTensor
|
||
|
|
||
|
|
||
|
def get_data(data_root, download=False):
    """Construct the CIFAR-10 training and test datasets under *data_root*.

    A single ``ToTensor`` instance is shared by both splits, so images are
    returned as tensors.

    Args:
        data_root: Directory passed as ``root`` to ``datasets.CIFAR10``.
        download: Forwarded unchanged to ``datasets.CIFAR10`` for both
            splits. With the default ``False``, the files must already
            exist under ``data_root``.

    Returns:
        A ``(training_data, testing_data)`` tuple of CIFAR-10 datasets.
        The first element is built with ``train=True`` and the second with
        ``train=False``.
    """
    to_tensor = ToTensor()

    def _split(is_train):
        # One constructor call per split; only the ``train`` flag differs.
        return datasets.CIFAR10(
            root=data_root,
            train=is_train,
            download=download,
            transform=to_tensor,
        )

    # Tuple elements are evaluated left to right, so the training split is
    # built before the test split.
    return _split(True), _split(False)
|
||
|
|
||
|
|
||
|
def load_data(config, data_root, *, train_fraction=0.8, seed=None,
              num_workers=2, download=False):
    """Build train/validation DataLoaders from the CIFAR-10 training split.

    The CIFAR-10 *training* set returned by ``get_data`` is randomly divided
    into two disjoint subsets. ``int(len * train_fraction)`` samples go to
    training and the rest to validation. The test split that ``get_data``
    also returns is not used here.

    Args:
        config: Mapping with a ``"batch_size"`` entry. The value is passed
            through ``int()``, so numeric values such as ``64.0`` are
            accepted.
        data_root: Directory passed to ``get_data``.
        train_fraction: Share of the training set assigned to the training
            subset; the remainder becomes the validation subset. Must lie
            strictly between 0 and 1. Defaults to 0.8.
        seed: When given, seeds the generator used by ``random_split`` so
            every call yields the same partition (useful when comparing
            runs). ``None`` (the default) draws the split from torch's
            default RNG, as before.
        num_workers: Worker processes used by each DataLoader.
        download: Forwarded to ``get_data``.

    Returns:
        ``(train_loader, test_loader)``. Despite its name, ``test_loader``
        iterates the held-out *validation* subset, in a fixed order; the
        training loader reshuffles every epoch.

    Raises:
        ValueError: If ``train_fraction`` is not in the open interval (0, 1).
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(
            f"train_fraction must be in (0, 1), got {train_fraction!r}")

    # Only the training split is partitioned; the official test split is
    # discarded.
    train_set, _ = get_data(data_root, download=download)

    n_train = int(len(train_set) * train_fraction)
    lengths = [n_train, len(train_set) - n_train]
    if seed is None:
        train_subset, val_subset = random_split(train_set, lengths)
    else:
        import torch  # only needed to build a seeded generator
        generator = torch.Generator().manual_seed(seed)
        train_subset, val_subset = random_split(
            train_set, lengths, generator=generator)

    batch_size = int(config["batch_size"])
    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    # Evaluation does not benefit from shuffling. A fixed order keeps batch
    # composition stable across epochs for a given split.
    test_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader
|