from torch.utils.data import random_split, DataLoader from torchvision import datasets from torchvision.transforms import ToTensor def get_data(data_root, download=False): transform = ToTensor() # Download training data from open datasets. training_data = datasets.CIFAR10( root=data_root, train=True, download=download, transform=transform, ) # Download test data from open datasets. testing_data = datasets.CIFAR10( root=data_root, train=False, download=download, transform=transform, ) return training_data, testing_data def load_data(config, data_root): train_set, test_set = get_data(data_root) test_abs = int(len(train_set) * 0.8) train_subset, test_subset = random_split( train_set, [test_abs, len(train_set) - test_abs]) train_loader = DataLoader( train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=2) test_loader = DataLoader( test_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=2) return train_loader, test_loader