DataLoader的collate_fn参数

引言

在基于Pytorch的深度学习的代码中基本都是用Dataset和DataLoader来加载数据,本文起源于一个疑惑:看到某Dataset类中定义的collate_fn中有这么一句:

def collate_fn(self, batch)  
    ims, classes = list(zip(*batch)))

我一想,本来从DataLoader加载出来的batch不就是数据跟标签是分开的吗,这里怎么又要给他分开?于是研究了一波:

先放上用于演示的代码:

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn

x = [[1,2],[3,4],[5,6],[7,8]]
y = [[3],[7],[11],[15]]

X = torch.tensor(x).float()
Y = torch.tensor(y).float()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
X = X.to(device)
Y = Y.to(device)

class MyDataset(Dataset):
    def __init__(self,x,y):
        self.x = torch.tensor(x).float()
        self.y = torch.tensor(y).float()
    def __len__(self):
        return len(self.x)
    def __getitem__(self, ix):
        return self.x[ix], self.y[ix]
ds = MyDataset(X, Y)

dl = DataLoader(ds, batch_size=2, shuffle=True)

class MyNeuralNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_to_hidden_layer = nn.Linear(2,8)
        self.hidden_layer_activation = nn.ReLU()
        self.hidden_to_output_layer = nn.Linear(8,1)
    def forward(self, x):
        x = self.input_to_hidden_layer(x)
        x = self.hidden_layer_activation(x)
        x = self.hidden_to_output_layer(x)
        return x

mynet = MyNeuralNet().to(device)
loss_func = nn.MSELoss()
from torch.optim import SGD
opt = SGD(mynet.parameters(), lr = 0.001)

import time
loss_history = []
start = time.time()
for _ in range(50):
    for data in dl:
        x, y = data
        opt.zero_grad()
        loss_value = loss_func(mynet(x),y)
        loss_value.backward()
        opt.step()
        loss_history.append(loss_value)
end = time.time()
print(end - start)

val_x = [[10,11]]

val_x = torch.tensor(val_x).float().to(device)

mynet(val_x)

代码是jupyter写的,用pycharm自己改改也能运行,要完成的事情就是训练一个模型计算两数之和

现在没用collate_fn这里print一下data看得出来,数据和标签是分开的,长度都是一个batch(2)的大小

现在我把collate_fn写进dataset里,直接打印一下batch看看效果(这里需要把batch return了,因为dataloader如果选择了用collate_fn参数的话最终的输出是用collate_fn指定的函数的返回值给出的(这里设置函数名同参数名,这是默认的,可以自定义函数名))

别忘改参数:

结果如下:

可以发现不止collate_fn函数内部的batch值是一个个样本对(原来是图片和真值分开),dataloader中返回的data也是一个个样本对,这就涉及到collate_fn的工作原理了:

collate_fn的作用:

  1. collate_fn: 这是一个函数或者None。当数据集中的样本具有不同的大小或类型时,collate_fn函数用于将样本组合成一个批次。通常,它会处理样本的填充、对齐或转换等操作,以便批次中的所有样本具有相同的形状或类型。如果样本具有相同的大小或类型,可以将collate_fn设置为None

  2. collate_fn所指向的函数本来就是对一个批次的数据进行操作,在数据增强,数据预处理等方面也有着广泛应用

dataloader的工作流程:

collate_fn的情况下,随机/不随机(shuffle)选择batch个索引传入dataset里的__getitem__(self, idx)得到对应的数据(样本对),经过处理分别传出 数据 和 真值(样本对被自动拆开重组了)

collate_fn的情况下,随机/不随机(shuffle)选择batch个索引传入dataset里的__getitem__(self, idx)得到对应的数据,将这些数据(样本对)传入collate_fn指定函数进行处理,因为没有自动的将样本对拆开重组,所以一般需要手工操作,正如文章开头的那段代码

def collate_fn(self, batch)  
    ims, classes = list(zip(*batch)))

一样,手工拆开重组之后对数据进行定义的操作(数据增强,其他预处理等等)再将数据和标签分别输出(这样是为了模仿不用collate_fn时将样本对自动拆开重组)

总结:

collate_fn函数会帮助批量化地自定义地处理数据,是个很好的很灵活的功能

这个视频会帮助理解

PyTorch DataLoader工作原理可视化 (qq.com)

  • 26
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值