趁着假期,准备好好学习PyTorch,以下是DataLoader的学习笔记。在此,特别感谢B站UP主【我是土堆】,土堆老师讲解得非常认真详细,各位pytorch入门学习者可以去看看~
视频链接:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili
从dataset里面取数据,怎么取,取多少,这些就是由DataLoader实现的。
PyTorch官网的DataLoader解释:torch.utils.data — PyTorch 1.10.1 documentation
以下代码实现的功能是,对test_data进行不放回随机抽取,每次抽取64张图片及其标签并打包为imgs和targets,直到把dataset抽取完毕。若最后一次抽取少于64张,则舍去(drop_last=Ture)。这样的抽取进行两次(for epoch in range(2)),每次抽取前先进行洗牌(Shuffle=True)。
import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=True)
writer = SummaryWriter("dataloader")
for epoch in range(2):
step = 0
for data in test_loader:
imgs,targets = data
# print(imgs.shape)
# print(targets)
writer.add_images("Epoch:{}".format(epoch),imgs,step)
step = step + 1
writer.close()
Tensorboard展示两次抽取结果: