torch.utils.data.DataLoader是pytorch提供的数据加载类,初始化函数如下,
torch.utils.data.DataLoader(dataset,batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
dataset,batch_size等参数重要且容易理解,而collate_fn参数就不太直白,官方解释为:
collate_fn (callable, optional) – merges a list of samples to form a mini-batch
不明不白。
其实,collate_fn可理解为函数句柄、指针...或者其他可调用类(实现__call__函数)。 函数输入为list,list中的元素为欲取出的一系列样本。具体如下
indices = next(self.sample_iter)
batch = self.collate_fn([dataset[i] for i in indices])
其中self.sampler_iter即采样器,返回下一个batch中样本的序号,indices。
通过collate_fn函数可以对这些样本做进一步的处理(任何你想要的处理),原则上返回值应当是一个有结构的batch。而DataLoader每次迭代的返回值就是collate_fn的返回值。
以图像关键点训练数据采样举例:
采样器调用我们自定义数据类的__getitem__(self, idx)函数获取训练样本,假设__getitem__函数返回字典:
{
"image": [[...],[...]]#一副图像,tensor,格式1CHW
"keypoints":[[x1,y1],[x2,y2],...]#图像中的关键点,tensor
}
那么通过sampler采样一个batch的样本时,返回的是一个list,格式如下
[
{"image": [[...],[...]],
"keypoints":[[x1,y1],[x2,y2],...]},
{"image": [[...],[...]],
"keypoints":[[x1,y1],[x2,y2],...]}
]
我们知道,神经网络在处理图像数据时,可以一次输入一个batch的数据,格式为(BCHW)的tensor,因此我们需要将数据变成如下格式
{
"images":[[[...]],[[...]]]#多幅图像,Tensor,格式:BCHW
"keypoints":[tensor,tensor]#每个元素都是一个list或tensor,对应与各image中的关键点
}
这个转换过程就可以通过collate_fn函数完成。