Pytorch的DataLoarder中的collate_fn参数

该博客介绍了如何在PyTorch中自定义DataLoader的collate_fn函数来处理批量数据。通过示例代码展示了如何进行数据标准化,包括对输入和目标序列进行padding,以确保所有样本在批处理中具有相同的长度。此外,还解释了返回的各个参数的含义,例如输入序列、序列长度、目标序列和填充标记。这有助于在训练神经网络时有效地处理变长序列。
摘要由CSDN通过智能技术生成

使用方法

作为dataLoader的形参,不传入的时候使用默认的,可以自己定义。

DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)

自己定义:

def collate_fn(examples):
    """
    wfj:该函数表示对于batch_size中的每一个元素做以下一下的操作,通常用来进行数据的标准化工作
    """
    print("==========================")
    print(examples)
    print(len(examples))
    lengths = torch.tensor([len(ex[0]) for ex in examples])
    inputs = [torch.tensor(ex[0]) for ex in examples]
    targets = [torch.tensor(ex[1]) for ex in examples]
    # 对batch内的样本进行padding,使其具有相同长度
    inputs = pad_sequence(inputs, batch_first=True, padding_value=vocab["<pad>"])
    targets = pad_sequence(targets, batch_first=True, padding_value=vocab["<pad>"])
    #输出的几个参数的解释:解释变量;每个解释变量的长度;被解释变量;是否为填充位的标记。
    return inputs, lengths, targets, inputs != vocab["<pad>"]

打印信息

在这里插入图片描述
我们的batch_size设置的是32。

解析

所以collate_fn接受的一个参数,就是Dataloader迭代取出的每个batch_size,我们可以在collate_fn中对每个batch_size的数据进行相关的操作和个性化的处理。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值