前言
之前看了不到pytorch代码,对Dataloader的大部分参数都比较了解,今天看代码时,发现了一个参数collate_fn ,之前论文代码没怎么见过,也就自动忽略了,今天既然遇到了,就突然来了好奇心,想搞清楚用途及用法,以下为正文。
问题及实验
- 问题
今天看代码时出现如下问题:
对Dataloader参数中的collate_fn甚感好奇,故想一探究竟。
2. 实验
1):myDataset()类中__getitem__方法返回的数据
代码如下:
测试结果:
可见myDataset()类中__getitem__方法返回值为两个,网络的输入数据为128x40的tensor,输出是个分类标签数据。
2):Dataloader 运行过程
过程:首先Dataloader 会根据batch参数生成一个长度为batch值的列表,列表的值是myDataset()类中__getitem__()的参数,如果shuffle为True ,列表的值就是从0到len(data)中随机抽样索引。然后列表的索引值会依次送入__getitem__()方法,最终返回一个列表的数据,该列表数据会作为collate_fn 函数的参数传入,最终得到一个batch的数据。其中collate_fn 函数可以使用系统默认的也可以使用自己设计的,非常灵活。
debug验证:
程序:
debug1:
进入self._next_data():
index 即为根据batch参数得到的列表:
进入fetch函数:
debug如下:
即fetch函数通过传入index列表得到一个新的列表数据,然后该列表数据通过collate_fn()函数得到最终数据。
--------------------------------------------------------------分割线----------------------------------------------------------
以下无关Dataloarder的使用方法,仅仅是研究以下这里的自定义collate_fn函数的功能。
进入collate_fn函数:
关于zip的拆包功能研究:
测试结果:
通过pad_sequence得到最终结果:
参考文章:
https://blog.csdn.net/dong_liuqi/article/details/114521240