- 首先,查看Pytorch官网的帮助文档
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
其参数:
其中常用的参数有,dataset
为要使用的数据集;batch_size
为一次性要加载的数据个数;shuffle
为是否打乱数据,True
为打乱,False
为不打乱;num_workers
我们加载数据为多进程还是单进程,如果是单进程就写0
,如果是多进程就写>=1
;在windows
下如果是写多进程可能会报错,可以直接写成0,在Linux
下如果有多进程则可以写多进程;drop_last
为总共的数据除以batch_size
是否希望有余数,若不希望有余数则True
,若希望有余数则False
。
- 其使用
简单粗暴上代码:
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
# 创建SummaryWriter模板
writer = SummaryWriter("logs")
# 创建transforms.ToTensor模板
tran_tensor = transforms.ToTensor()
# 创建 torchvision.datasets.CIFAR10
# root为datasets.CIFAR10的目录,train为是否是训练集;
# transforms做数据增强的方法,download数据集是否重新下载;
test_set = torchvision.datasets.CIFAR10(root ="./dataset",train=False,transforms=tran_tensor,download=True)
# 创建DataLoader
# datasets加载数据集(test_set),batch_size一次性加载数据的个数;
# shuffle是否洗牌,True洗牌,False不洗牌;
# num_workers多线程,0为单线程,>=1为多线程(windows下可能会报错);
# drop_last最后剩余的数据集(总数据集个数除以一次性加载的数据数)是否舍取,False不舍取,True舍取;
test_loader = DataLoader(dataset=test_set,batch_size=64,shuffle=True,num_workers=0,drop_last=False)
# 测试shuffle writer.add_images,注意这里使用的是add_images,比以前多了s
# 输入了两个epoch看看里面的图片是不是一样的,如果不一样则证明shuffle=True洗牌成功!
for epoch in range(2):
step = 0
for data in test_loader:
imgs,targets = data
writer.add_images("Epoch:{}".format(epoch), imgs, step)
step = step+1
writer.close()
run
之后,输入命令行:
tensorboard --logdir=logs
两个epoch,一样的step,里面的数据是不一样的,则证明shuffle=True成功!
上一章 6.初识Pytorch之torchvision中的数据集使用
下一章 8.初识Pytorch之nn.Module神经网络基本架构的使用