1
0
Fork 0
This repository has been archived on 2024-02-17. You can view files and clone it, but cannot push or open issues or pull requests.
TP_IA/dataset.py

45 lines
1.1 KiB
Python
Raw Normal View History

import torch
from torch.utils.data import random_split, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
def get_data(data_root, download=False):
    """Return the CIFAR-10 training and test splits stored under ``data_root``.

    Both splits share a single ``ToTensor`` transform instance.

    Args:
        data_root: Directory passed to ``datasets.CIFAR10`` as ``root``.
        download: Forwarded to ``datasets.CIFAR10``. It defaults to False,
            so torchvision is not asked to fetch the archive unless the
            caller opts in.

    Returns:
        A ``(training_data, testing_data)`` tuple of ``CIFAR10`` datasets,
        training split first.
    """
    to_tensor = ToTensor()

    def _split(is_train):
        # train=True selects the training split; train=False the test split.
        return datasets.CIFAR10(
            root=data_root,
            train=is_train,
            download=download,
            transform=to_tensor,
        )

    # Tuple elements are evaluated left to right: training split first.
    return _split(True), _split(False)
def load_data(config, data_root, *, train_fraction=0.8, num_workers=2,
              seed=None, download=False):
    """Build training and validation loaders from the CIFAR-10 training split.

    The training split returned by ``get_data`` is randomly partitioned into
    a training subset (``train_fraction`` of the samples, rounded down) and a
    held-out validation subset (the remainder). The official test split that
    ``get_data`` also returns is not used here.

    Args:
        config: Mapping that provides ``"batch_size"``; the value is passed
            through ``int()`` before use.
        data_root: Dataset directory forwarded to ``get_data``.
        train_fraction: Fraction of the training split assigned to the
            training subset. Defaults to 0.8, the previously hard-coded ratio.
        num_workers: Worker processes for each ``DataLoader``. Defaults to 2.
        seed: If given, seeds a dedicated generator for ``random_split`` so
            the train/validation partition is identical across calls. If
            None (default), ``random_split`` uses torch's default generator,
            as before.
        download: Forwarded to ``get_data``. Defaults to False.

    Returns:
        ``(train_loader, val_loader)``. The training loader reshuffles every
        epoch; the validation loader iterates in a fixed order.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f"train_fraction must be in [0, 1], got {train_fraction!r}")

    # Only the training split is partitioned; the test split is discarded.
    train_set, _ = get_data(data_root, download=download)

    train_size = int(len(train_set) * train_fraction)
    val_size = len(train_set) - train_size

    split_kwargs = {}
    if seed is not None:
        # A private seeded generator makes the split reproducible without
        # altering torch's global RNG state.
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(
        train_set, [train_size, val_size], **split_kwargs)

    batch_size = int(config["batch_size"])
    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers)
    # The validation subset is already a random sample; iterating it in a
    # fixed order keeps evaluation passes reproducible.
    val_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers)
    return train_loader, val_loader