import torch
import torchvision
import torch.nn as nn
import numpy as np
import torchvision.transforms as transforms
# 下载CIFAR-10数据集到当前data文件夹中
train_dataset = torchvision.datasets.CIFAR10(root='data/',
train=True,
transform=transforms.ToTensor(),
download=True)
# 从本地硬盘上读取一条数据 (包括1张图像及其对应的标签)
image, label = train_dataset[0]
print(image.size()) #输出 torch.Size([3, 32, 32])
print(label) #输出 6
# 数据加载准备 (开启数据加载的线程和队列).
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=64, #该参数表示每次读取的批样本个数
shuffle=True) #该参数表示读取时是否打乱样本顺序
# 创建迭代器
data_iter = iter(train_loader)
# 当迭代开始时, 队列和线程开始读取数据
images, labels = data_iter.next()
print(images.size()) #输出 torch.Size([64, 3, 32, 32])
print(labels.size()) #输出 torch.Size([64])
# 实际使用时使用下面的方式读取每一批(batch)样本
for images, labels in train_loader:
# 在此处添加训练代码
pass