DataLoader 官方文件查看方法:pytorch官网 -> document -> pytorch -> 搜索 dataloader (搜不到直接左边找torch.utils.data)
import torchvision
# 准备的测试数据集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10("./dataset_2",train=False,transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=False,num_workers=0,drop_last=True)
# 会返回一个img元组,和一个target元组
# 测试数据集中的第一张图片及target
img,target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter("dataloader")
for epoch in range(2):
step = 0
for data in test_loader:
imgs,targets = data
# print(imgs.shape)
# print(targets)
writer.add_images("Epoch: {}".format(epoch),imgs,step)
step = step + 1
writer.close()
Dataloader 代码的解释:
test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=False,num_workers=0,drop_last=True)
batch_size=64 # 每次抓取64张图片
shuffle=False # False 每次按顺序抓取; True 每次随即抓取
num_workers=0 # 指定数据加载时使用的进程数(workers的数量)
drop_last=True # 在最后的一组图片数量不够时,False 表示不舍去这一组; True 表示舍去
Dataloader 代码的返回值:
for images, labels in test_loader:
print(images.shape) # 输出: torch.Size([64, 3, 32, 32])
print(labels.shape) # 输出: torch.Size([64])
# 注意:这里只是展示了如何访问数据和标签的形状,并没有进行任何实际的数据处理或模型评估。
break # 只打印第一个批次的数据,防止无限循环
当你迭代test_loader
时,每次迭代返回的值是一个元组(tuple),其中包含两个元素:
-
images:一个形状为
[batch_size, channels, height, width]
的张量(Tensor),其中batch_size
是每批的样本数(这里是64),channels
是通道数(对于CIFAR-10,它是3,代表RGB),height
和width
是图像的高度和宽度(对于CIFAR-10,都是32)。 -
labels:一个形状为
[batch_size]
的张量,包含每个样本的类别标签。这些标签是整数,范围从0到9,代表CIFAR-10数据集中的10个类别。
如何加入一个任意值:
"Epoch: {}".format(epoch)
这里的 format(file_name)
方法调用会将 file_name
变量的值插入到字符串 "{}.txt"
中的 {}
占位符位置。假设 file_name
的值是 "example"
,那么结果字符串将是 "example.txt"
f"{file_name}.txt"
这会在执行时自动将 file_name
的值插入到字符串中,无需显式调用 format()
方法