视频链接:DataLoader的使用_哔哩哔哩_bilibili
目录
什么是DataLoader?
把一个数据集比作一副扑克牌,一张扑克牌就是一个数据
把DataLoader比作神经网络的手,手去抓牌。一次抓几张,抓牌有没有顺序,用一只手还是两只手,等等,这都是通过设置DataLoader的参数决定的。
如何使用DataLoader?(以单张图举例)
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=torchvision.transforms.ToTensor(),
download=True)
# 在test_set中抓样本,每一批抓4个样本,随机抓,
# 只有主进程去加载batch数据,
# 丢弃因数据集样本数不能被batch_size整除而产生的最后一个不完整的mini_batch
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# 在CIFAR10的官方文档中的__getitem__可以看到它返回的先是img后是target
# 测试数据集中第一张图片及其target
img, target = test_data[0]
print(img.shape)
print(target)
运行结果:
单个图片时,[3, 32, 32]指的是3通道、图片尺寸32×32
多张图时
for data in test_loader:
imgs, targets = data
print(imgs.shape)
print(targets)
运行结果:
torch.Size([4, 3, 32, 32])中的4是指batch_size,4张图片
tensor([7, 9, 9, 4])是指4张图片的target打包在一起,分别是7, 9, 9, 4
如何使用DataLoader?(以多张图举例)
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=torchvision.transforms.ToTensor(),
download=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
"""
上一行的drop_last的功能是,当最后一步显示图片不足64张图片的时候,
是否舍弃最后一步,True就是舍去,False不舍,但最后一步显示的图片可能会不足64
"""
writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
imgs, targets = data
writer.add_images("test_data2", imgs, step)
step += 1
writer.close()
在Terminal打开tensorboard:
shuffle作用示例
shuffle表示是否随机“抓牌”
#shuffle的作用
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)
writer = SummaryWriter("dataloader")
forepochinrange(2):#epoch取0或1
step = 0
fordataintest_loader:
imgs, targets = data
writer.add_images("Epoch: {}".format(epoch), imgs, step)
step += 1
writer.close()