Pytorch Dataset,TensorDataset,Dataloader,Sampler关系

Dataloader

Dataloader是数据加载器,组合数据集和采样器,并在数据集上提供单线程或多线程的迭代器。

所以Dataloader的参数必然需要指定数据集Dataset和采样器Sampler。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
  • dataset (Dataset) – 数据集。
  • batch_size (int, optional) – 每个batch加载样本数。
  • shuffle (bool, optional) – True则打乱数据.
  • sampler (Sampler, optional) – 采样器,如指定则忽略shuffle参数。
  • num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载
  • collate_fn (callable, optional) – 获取batch数据的回调函数,也就是说可以在这个函数中修改batch的形式
  • pin_memory (bool, optional) –
  • drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。

Dataset和TensorDataset

所有其他数据集都应该进行子类化。所有子类应该override __len____getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。

TensorDataset是Dataset的子类,已经复写了__len____getitem__方法,只要传入张量即可,它通过第一个维度进行索引。
TensorDataset示例
所以TensorDataset说白了就是将输入的tensors捆绑在一起,然后__len__是任何一个tensor的维度,__getitem__表示每个tensor取相同的索引,然后将这个结果组成一个元组,源码如下,要好好理解它通过第一个维度进行索引的意思(针对tensors里面的每一个tensor而言)。

class TensorDataset(Dataset):
	def __init__(self,*tensors):
		assert all(tensors[0].size(0)==tensor.size(0) for tensor in tensors)
		self.tensors = tensors
	def __getitem__(self,index):
		return tuple(tensor[index] for tensor in self.tensors)
	def __len__(self):
		return self.tensors[0].size(0)

Sampler和RandomSampler

Sampler与Dataset类似,是采样器的基础类。

每个采样器子类必须提供一个__iter__方法,提供一种迭代数据集元素的索引的方法,以及返回迭代器长度的__len__方法。

所以Sampler必然是关于索引的迭代器,也就是它的输出是索引。

而RandomSampler与TensorDataset类似,RandomSamper已经实现了__iter____len__方法,只需要传入数据集即可。
在这里插入图片描述
猜想理解RandomSampler的实现方式,考虑到这个类实现需要传入Dataset,所以__len__就是Dataset的__len__,然后__iter__就可以随便搞一个随机函数对range(length)随机即可。

综合示例

结合TensorDataset和RandomSampler使用Dataloader
在这里插入图片描述
这里即可理解Dataloader这个数据加载器其实就是组合数据集和采样器的组合。所以那就是先根据Sampler随机拿到一个索引,再用这个索引到Dataset中取tensors里每个tensor对应索引的数据来组成一个元组。

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值