| import torch |
| from torchvision import datasets |
| import torchvision.transforms as transforms |
|
|
# Mini-batch size used by data_loader() for both the train and test DataLoaders.
# data_loader() reads this module-level value each time it is called.
batch_size = 128
|
|
def data_transform(*, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Build the training and evaluation transform pipelines.

    The training pipeline applies random augmentation before converting to a
    tensor and normalizing:
      - random horizontal flip;
      - random rotation of up to 10 degrees;
      - a 32x32 random crop taken after 4 pixels of padding;
      - colour jitter.
    The evaluation pipeline only converts to a tensor and normalizes, so test
    images are never randomly perturbed.

    Both pipelines normalize with the same ``mean``/``std``. This guarantees
    that the train and test splits are scaled identically.

    Args:
        mean: Per-channel means passed to ``transforms.Normalize``. The
            default of 0.5 per channel (three channels) preserves the original
            behaviour.
        std: Per-channel standard deviations passed to ``transforms.Normalize``.
            Defaults to 0.5 per channel.

    Returns:
        A ``(transform_train, transform_test)`` tuple of ``transforms.Compose``
        pipelines.
    """
    # NOTE: with the 0.5 defaults, ToTensor's [0, 1] output is mapped to
    # [-1, 1]. Pass dataset-specific statistics here if preferred.
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        # Pad by 4 px on each side, then crop back to 32x32.
        transforms.RandomCrop(32, padding=4),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # No augmentation at evaluation time: only tensor conversion plus the
    # exact same normalization as the training pipeline.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    return transform_train, transform_test
|
|
def data_loader(transform_train, transform_test, *, root='./data', num_workers=2, download=True):
    """Create CIFAR-10 train and test DataLoaders.

    Args:
        transform_train: Transform applied to each training sample, e.g. the
            first element returned by ``data_transform()``.
        transform_test: Transform applied to each test sample, e.g. the second
            element returned by ``data_transform()``.
        root: Directory in which the CIFAR-10 files are stored. The default
            ``'./data'`` is resolved relative to the current working directory.
        num_workers: Number of worker processes for each DataLoader
            (default 2).
        download: Passed to ``datasets.CIFAR10``; when True the dataset is
            fetched into ``root`` if it is not already present.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Both loaders use the
        module-level ``batch_size``. Only the training loader shuffles, so the
        evaluation order is fixed.
    """
    train_dataset = datasets.CIFAR10(root=root, train=True, download=download,
                                     transform=transform_train)
    test_dataset = datasets.CIFAR10(root=root, train=False, download=download,
                                    transform=transform_test)

    # batch_size is the module-level setting, looked up at call time.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                              shuffle=False, num_workers=num_workers)
    return train_loader, test_loader
|
|