torch.utils.data.dataloader参数collate_fn简析

torch.utils.data.DataLoader是pytorch提供的数据加载类,初始化函数如下,

torch.utils.data.DataLoader(dataset,batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

dataset,batch_size等参数重要且容易理解,而collate_fn参数就不太直白,官方解释为:

collate_fn (callableoptional) – merges a list of samples to form a mini-batch

不明不白。

其实,collate_fn可理解为函数句柄、指针...或者其他可调用类(实现__call__函数)。 函数输入为list,list中的元素为欲取出的一系列样本。具体如下

indices = next(self.sample_iter)
batch = self.collate_fn([dataset[i] for i in indices])

其中self.sampler_iter即采样器,返回下一个batch中样本的序号,indices。

通过collate_fn函数可以对这些样本做进一步的处理(任何你想要的处理),原则上返回值应当是一个有结构的batch。而DataLoader每次迭代的返回值就是collate_fn的返回值。

以图像关键点训练数据采样举例:

采样器调用我们自定义数据类的__getitem__(self, idx)函数获取训练样本,假设__getitem__函数返回字典:

{
"image": [[...],[...]]#一副图像,tensor,格式1CHW
 "keypoints":[[x1,y1],[x2,y2],...]#图像中的关键点,tensor
}

那么通过sampler采样一个batch的样本时,返回的是一个list,格式如下

[
{"image": [[...],[...]],
 "keypoints":[[x1,y1],[x2,y2],...]},

{"image": [[...],[...]],
 "keypoints":[[x1,y1],[x2,y2],...]}
]

我们知道,神经网络在处理图像数据时,可以一次输入一个batch的数据,格式为(BCHW)的tensor,因此我们需要将数据变成如下格式

{
"images":[[[...]],[[...]]]#多幅图像,Tensor,格式:BCHW
"keypoints":[tensor,tensor]#每个元素都是一个list或tensor,对应与各image中的关键点
}

这个转换过程就可以通过collate_fn函数完成。

 

  • 13
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值