pytorch之深入理解collate_fn

本文详细介绍了PyTorch中DataLoader的collate_fn功能,它用于整理批量数据。通过实例展示了如何自定义collate_fn以处理不同维度的数据,并讨论了在数据不规整时collate_fn的重要性。尽管默认的collate_fn在大多数情况下足够,但在特殊情况下,如数据维度不固定,自定义collate_fn可以避免错误。然而,对于定长输入的神经网络,collate_fn的应用场景有限。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前言

import torch.utils.data as tud

collate_fn:即用于collate的function,用于整理数据的函数。
说到整理数据,你当然也要会用tud.Dataset,因为这个你定义好后,才会产生数据嘛,产生了数据我们才能整理数据嘛,而整理数据我们使用collate_fn

1.dataset

我们必须先看看tud.Dataset如何使用,以一个例子为例:

class mydataset(tud.Dataset):
    def __init__(self,data):
        self.data=data
    def __len__(self):#必须重写
        return len(self.data)
    def __getitem__(self,idx):#必须重写
        return self.data[idx]
#构造训练数据
a=np.random.rand(4,3)#4个数据,每一个数据是一个向量。
print(a)

在这里插入图片描述

#制作dataset
dataset=mydataset(a)
len(dataset)#调用了你上面定义的def __len__()那个函数
#4
dataset[0]#调用了你上面定义的def __getitem__()那个函数,传入的idx=0,也就是取第0个数据。
#array([0.56998216, 0.72663738, 0.3706266 ])

2.dataloader之collate_fn

dataloader=tud.DataLoader(dataset,batch_size=2)

batch_size=2即一个batch里面会有2个数据。我们以第1个batch为例,tud.DataLoader会根据dataset取出前2个数据,然后弄成一个列表,如下:

batch=[dataset[0],dataset[1]]
batch

[array([0.56998216, 0.72663738, 0.3706266 ]),
array([0.3403586 , 0.13931333, 0.71030221])]

然后将上面这个batch作为参数交给collate_fn这个函数进行进一步整理数据,然后得到real_batch,作为返回值。如果你不指定这个函数是什么,那么会调用pytorch内部的collate_fn

也就是说,我们如果自己要指定这个函数,collate_fn应该定义成下面这个样子。

def my_collate(batch):#batch上面说过,是dataloader传进来的。
	***#你自己定义怎么整理数据。下面会说。
	real_batch=***
	return real_batch

那么pytorch内部默认的collate_fn函数长什么样呢?我们先观察下面的例子:

it=iter(dataloader)
nex=next(it)#我们展示第一个batch经过collate_fn之后的输出结果
print(nex)

tensor([[0.5700, 0.7266, 0.3706],
[0.3404, 0.1393, 0.7103]], dtype=torch.float64)

上面这个返回的结果就是real_batch。也就是collate_fn函数的返回值!!也就是说collate_fn将batch变成了上面的real_batch。

我们重新写一遍

batch:
[array([0.56998216, 0.72663738, 0.3706266 ]),
array([0.3403586 , 0.13931333, 0.71030221])]
real_batch:
tensor([[0.5700, 0.7266, 0.3706],
[0.3404, 0.1393, 0.7103]], dtype=torch.float64)

将batch变成上述real_batch很容易呀,就是把一个列表,变成了矩阵,我们也会!我们下面就来自己写一个collate_fn实现这个功能。

def my_collate(batch):
    real_batch=np.array(batch)
    real_batch=torch.from_numpy(real_batch)
    return real_batch
dataloader2=tud.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
it=iter(dataloader2)
nex=next(it)#我们展示第一个batch经过collate_fn之后的输出结果
print(nex)

tensor([[0.5700, 0.7266, 0.3706],
[0.3404, 0.1393, 0.7103]], dtype=torch.float64)

这不就和默认的collate_fn的输出结果一样了嘛!

3.应用情形

通常,我们并不需要使用这个函数,因为pytorch内部有一个默认的。但是,如果你的数据不规整,使用默认的会报错。例如下面的数据。
假设我们还是4个输入,但是维度不固定的。之前我们是每一个数据的维度都为3。

a=[[1,2],[3,4,5],[1],[3,4,9]]
dataset=mydataset(a)
dataloader=tud.DataLoader(dataset,batch_size=2)
it=iter(dataloader)
nex=next(it)
nex

使用默认的collate_fn,直接报错,要求相同维度。
在这里插入图片描述
这个时候,我们可以使用自己的collate_fn,避免报错。

那还想问,能不能举个现实例子,什么时候长度会不规整?

一个最典型的例子其实就是自然语言处理领域了。我们有100个句子,这些句子长短不一。

聪明的你也许会想,为什么我不将每一个句子变成长度一样呢?例如使用填充法,使用0填充,将所有短的句子填充到最大长度。这样的话,100个句子长度一样,就不需要自己定义collate_fn了,用默认的即可。

这样确实可以,但是如果这100个句子中最长的那个句子长度特别大的话,意味着大量的句子需要大量填充,后果就是,输入到神经网络的时候会用很多内存。这个时候collate_fn就有用了,可以用来实现batch级别的填充(之前那种填充是数据集级别的填充),例如batch size为4,那么只需要重新实现collate_fn,将这4个句子填充到一样的长度即可。


完结撒花

评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

音程

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值