今天学习DataLoader的使用,DataLoader主要是对数据集进行封装,批量读取数据。
# 导入所需库
import torchvision # torchvision库提供常用的数据集、模型、变换等
from torch.utils.data import DataLoader # DataLoader提供对数据集的封装,批量读取数据
from torch.utils.tensorboard import SummaryWriter # tensorboard可视化工具,用于展示训练过程
# 加载CIFAR-10测试数据集,设置transform响应式变换为tensor类型,使图片可以直接输入模型
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
# 使用DataLoader将测试数据集转化为batch_size规定大小的批量数据,
# shuffle参数表示每次从数据集中随机取出部分样本放入一个batch,num_worker表示读取数据的线程数,drop_last表示是否舍去最后一个大小不足batch_size的批量数据
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# 打印测试数据集中第一张图片的形状以及target
img, target =