Dataset告诉程序数据集在什么位置,DataLoader是一个加载器,从Dataset中取数据加载到神经网络中,取多少怎么取都由DataLoader中的参数控制,官网中的定义如下:
一、每次使用需要引用包
![](https://i-blog.csdnimg.cn/blog_migrate/5ab27e7860272d8b9d9e126ec9720897.png)
二、参数说明
CLASS torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False, pin_memory_device='')
其中大多数参数都有默认值,只有 dataset 没有,这个 dataset 需要自定义,附上链接:(38条消息) pytorch 基础认识——Dataset_堂小仙儿的博客-CSDN博客,主要告诉我们数据在什么位置,包括第一第二数据集有多少数据,具体可以看官网里 Dataset 有什么方法和类。
三、常见参数设置
batch_size(int,optional):例如有一摞牌,batch_size=1,每次抓牌抓一张,等于2,每次抓就抓两张;
shuffle(bool,optional):就是洗牌,每一局牌结束就要洗牌,如果和第一次牌的顺序一致,就返回 False,否则返回 True;
sampler (Sampler or Iterable, optional):随机抓取,官网中的定义是这样的:
sampler: Union[Sampler, Iterable, None] = None
一般情况下默认随即抓取
batch_sampler(Sampler or Iterable, optional): batch_sampler是一个批数据采样器,通过某种限制得到输入数据集的子集,即批数据;
num_workers(int, optional):多进程,默认为0,表示只用主进程(注:有时候不为零会出现报错,可以该成零看看能不能解决);
drop_last(bool, optional):50张牌每次抓三张牌,抓到最后还剩一张牌,这张牌是否舍弃,True 表示舍弃,False 表示保留;
举例:
import torchvision
from torch.utils.data import DataLoader
# 准备的测试集
test_data = torchvision.datasets.CIFAR10(root='F:\\pycharm\\text\\dataset',train=False,transform=torchvision.transforms.ToTensor())
# 最后一句 transform 是自定义,表示对数据做处理,都变成tensor类型的
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0,drop_last=False)
# 要处理的数据集是 test_data,每次处理4张图片,处理结束,顺序和处理之前一致,单线程处理,多出来的图片保留
img, target = test_data[0]
print(img.shape) # 输出:torch.Size([3, 32, 32]) 3通道,大小为32*32
print(target) # 输出:3 在数据集内标签为3
for data in test_loader:
imgs, targets = data
print(imgs.shape)
print(targets)
# for循环说的是,从 test_loader中取出数据,test_loader 实例化 DataLoader 这个类,引用了 dataset ,这个类里面需要重置 __getitem__方法,给的 test_data 就是重置的目标,返回值包括图片和标签,然后图片和标签被 dataloader 按照 batchsize 的大小分开,具体如下所示:
h
![](https://i-blog.csdnimg.cn/blog_migrate/527df17c02555e6f30289ec2dff0b8b8.png)
整个运行结果如下所示(只拿部分举例):
torch.Size([3, 32, 32])
3 #一二两句是单张图片的数据,3通道的32*32大小的图片,在数据集内的标签是3
# 当batchsize是4的时候,就从数据据里一次拿四张,都是3通道32*32的图片,每张图片的标签被打包了,如第五句所示,这里的采样时随机的,因为 sample 时
torch.Size([4, 3, 32, 32])
tensor([5, 0, 1, 4])
torch.Size([4, 3, 32, 32])
tensor([4, 9, 3, 6])