import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
1.作用
用于对数据进行批量加载和处理。它提供了一种便捷的方式来遍历数据集,尤其在处理大数据集时非常有用。
批量加载数据:DataLoader 将数据集分成小批次(batch),可以通过设置 batch_size 参数来指定每个批次的大小。例如,如果你设置 batch_size=4,那么 DataLoader 每次会返回 4 个数据样本。
打乱数据:shuffle=True 参数使数据在每个 epoch(训练轮次)开始时被打乱。这对于训练神经网络非常重要,因为它可以帮助防止模型过拟合到数据的特定顺序。
并行数据加载:num_workers 参数允许你使用多个子进程来加速数据加载。如果 num_workers=0,数据将在主进程中加载。通过增加 num_workers,可以提高数据加载的效率。
处理剩余数据:drop_last 参数决定是否丢弃最后一个不满批次大小的数据。如果设置为 True,最后一个批次的数据样本数不足 batch_size 时会被丢弃。如果设置为 False,会返回这个不满批次大小的数据。
数据变换:可以通过 transform 参数对数据进行预处理或增强,如数据归一化、数据增强等。
# 准备的测试集,transform=torchvision.transforms.ToTensor():将图像转换为 Tensor 格式。
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
# batch_size是将多少个图片打成一个包,shuffle是每一轮结束后要不要洗牌,drop_last是把不满足最后一包的零碎图片忽略
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# 测试数据集第一张图片
img, target = test_data[0]
print(img.shape)
print(target)
# torch.Size([3, 32, 32]) 3
2.遍历测试集并记录图像
# 初始化summarywriter,并将结果写在dataload里
writer = SummaryWriter('dataload')
# 步数初始化
step = 0
# 遍历测试集
for data in test_loader:
imgs, targets = data
print(imgs.shape) # 打印图像的size
print(targets) # 打印标签labels
writer.add_images('test_data', imgs,step) # 写入tensorboard,命名为test_data
step+=1
# writer.close()
# 加上epoch后遍历测试集
for epoch in range(2):
step = 0
for data in test_loader:
imgs, targets = data
writer.add_images("Epoch:{}".format(epoch), imgs, step)
step += 1
writer.close()
单次遍历主要用于评估模型的性能,或者简单地查看数据样本。
多轮使用epoch遍历通常用于训练模型。在这种情况下,模型在每个 epoch 中遍历整个数据集,以不断调整和优化模型的参数。每个 epoch 结束后,数据通常会被打乱(如果 shuffle=True),然后重新开始遍历数据集。