# 导入包 / Import packages
import torch.utils.data as data
import torchvision.datasets as dsets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
# 图像预处理 / Image preprocessing
# Training-time preprocessing pipeline, applied in order:
# resize, random horizontal flip, random 32x32 crop, conversion to a tensor.
transform = transforms.Compose([
    # `transforms.Scale` was a deprecated alias of `Resize` and has been
    # removed from torchvision; `Resize(40)` is the drop-in replacement.
    transforms.Resize(40),
    transforms.RandomHorizontalFlip(),
    # Random crop back down to 32x32 from the resized image.
    transforms.RandomCrop(32),
    transforms.ToTensor(),
])
# 准备数据 / Prepare the dataset
# CIFAR-10 training split stored under ./data/. With download=True the files
# are fetched if missing; `transform` is applied to each sample on access.
train_dataset = dsets.CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=transform,
)

# Inspect the first sample to sanity-check the preprocessing output.
image, label = train_dataset[0]
print(image.size())
print(label)
# 输出 / Output:
# Files already downloaded and verified
# torch.Size([3, 32, 32])
# 6
# 加载数据 / Load the data in batches
# Wrap the dataset in a loader that yields shuffled mini-batches of 100
# samples, using two worker processes for loading. Uses the `DataLoader`
# imported directly from torch.utils.data.
# NOTE(review): with num_workers > 0, platforms whose multiprocessing start
# method is "spawn" (e.g. Windows, macOS) need this top-level code under an
# `if __name__ == "__main__":` guard — confirm for the target platform.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=100,
                          shuffle=True,
                          num_workers=2)

# Pull a single batch manually, e.g. to inspect tensor shapes.
data_iter = iter(train_loader)
images, labels = next(data_iter)

# Full pass over the loader; the body is a placeholder for the training step.
for images, labels in train_loader:
    pass  # training code goes here