今天学习DataLoader的使用,DataLoader主要是对数据集进行封装,批量读取数据。
# 导入所需库
import torchvision # torchvision库提供常用的数据集、模型、变换等
from torch.utils.data import DataLoader # DataLoader提供对数据集的封装,批量读取数据
from torch.utils.tensorboard import SummaryWriter # tensorboard可视化工具,用于展示训练过程
# 加载CIFAR-10测试数据集,设置transform响应式变换为tensor类型,使图片可以直接输入模型
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
# 使用DataLoader将测试数据集转化为batch_size规定大小的批量数据,
# shuffle参数表示每次从数据集中随机取出部分样本放入一个batch,num_worker表示读取数据的线程数,drop_last表示是否舍去最后一个大小不足batch_size的批量数据
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# 打印测试数据集中第一张图片的形状以及target
img, target = test_data[0]
print(img.shape)
print(target)
# SummaryWriter用于保存训练过程中所得数据,并以图形化界面展示,用于监控训练过程
writer = SummaryWriter("logs/dataloader")
# 循环遍历测试数据集两次
for epoch in range(2):
step = 0 # 初始化当前批量数据的id
for data in test_loader: # 遍历测试数据集所有批量数据
img, targets = data # 获取当前批量数据中的图片租以及标签
writer.add_images("Epoch_{}".format(epoch), img, step) # 将当前批量数据的图片数据添加到SummaryWriter中并命名为Epoch_{epoch}
step = step + 1 # 记录当前批量数据结束,批量数据id加1
writer.close() # 关闭SummaryWrite,结束训练并且保存数据
在Pytorch Terminal中输入:tensorboard --logdir=logs/dataloader,即可得图像