DataLoader 的 collate_fn 解释与示例教程

导包

import torch
from torch.utils.data import Dataset
from typing import Any

数据

class CustomDataset(Dataset):
    
    def __init__(self, length) -> None:
        super().__init__()
        self.length = length
    
    def __getitem__(self, index=None):
        w1 = 3.14
        w2 = 4.27
        w = torch.tensor([w1, w2])
        feature = torch.rand(2) * 10
        noise = torch.randn_like(feature) * 0.01
        label = torch.matmul(w, feature.t())
        feature += noise
        # return feature, label.view(1)
        return feature, label
    
    def __len__(self):
        return self.length

dataset = CustomDataset(4)

Dataloader

dataloader = torch.utils.data.DataLoader(
                    			dataset, 
                    			batch_size=2, 
								)

for feature, label in dataloader:
    print(feature.shape, label.shape)

下述展示了,默认的 Dataload 的处理结果:
通过 torch.stack(feature),构建出 batch 数据;

torch.Size([2, 2]) torch.Size([2])
torch.Size([2, 2]) torch.Size([2])

常量直接拼接;
向量则会在前面添加一个 batch 纬度;

collate_fn

collate_fn:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成我们期望的数据格式;

如上述默认的输出结果所示:label.shape 为 torch.Size([2]),笔者想通过 collate_fn 修改 label.shapetorch.Size([2, 1]),下述代码实现这个功能:

def collate_fn(item):
    feature, label = zip(*item)
    feature = torch.stack(feature)
    label = torch.stack(label)
    label = label.view(-1, 1)
    return feature, label
    
dataloader = torch.utils.data.DataLoader(
                    			dataset, 
                    			batch_size=2, 
                    			collate_fn=collate_fn
								)

for feature, label in dataloader:
    print(feature.shape, label.shape)

输出如下:

torch.Size([2, 2]) torch.Size([2, 1])
torch.Size([2, 2]) torch.Size([2, 1])

collate_fn(item),传入的item的数据为:

[(tensor([[6.9436, 7.2040]]), tensor([[52.6007]])), (tensor([[7.1495, 2.8882]]), tensor([[34.7427]]))]
[(tensor([[1.5311, 9.9278]]), tensor([[47.1995]])), (tensor([[4.9614, 8.6232]]), tensor([[52.3849]]))]

feature, label = zip(*item) 故通过zip(*item)的方式,拆分出 feature 和 label 各自的数据,再借助torch.stack方法将其拼接出 batch 形状的数据。

  • 4
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
`collate_fn` 是在 PyTorch 的 `DataLoader` 中使用的一个参数,用于自定义数据在批量加载过程中的拼接方式。它接受一个批量的样本数据列表作为输入,并返回一个包含了对应字段拼接后的批量数据。 在数据加载过程中,`DataLoader` 会将每个样本数据传递给 `collate_fn` 函数进行处理。`collate_fn` 函数的主要作用是对样本数据进行定制化的处理,例如进行填充、截断、变换等操作,以满足模型的输入要求。 典型的 `collate_fn` 函数可以执行以下操作: - 对样本数据进行填充或截断,使得一个批次中的所有样本具有相同的长度。 - 将输入数据转换为张量形式,例如将文本转换为索引序列或将图像转换为张量。 - 对样本数据进行其他定制化的预处理操作,例如数据标准化、增强等。 下面是一个示例的 `collate_fn` 函数: ```python def collate_fn(batch): # 将批次中的样本数据分别取出 inputs = [item['input'] for item in batch] labels = [item['label'] for item in batch] # 处理输入数据,例如进行填充或截断 inputs = pad_sequence(inputs, batch_first=True) # 将输入和标签转换为张量 inputs = torch.tensor(inputs) labels = torch.tensor(labels) return {'input': inputs, 'label': labels} ``` 在这个示例中,假设每个样本数据是一个字典,包含了输入数据和标签。`collate_fn` 函数首先将批次中的输入数据和标签分别取出,然后对输入数据进行填充操作,使用 `pad_sequence` 函数对输入数据进行批量填充,并设置 `batch_first=True` 来保证批次维度在第一维。最后,将输入数据和标签转换为张量形式,并以字典的形式返回。 需要注意的是,`collate_fn` 函数的实现应根据你的数据集和模型的需求进行定制化,确保返回的批量数据符合模型的输入要求。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

jieshenai

为了遇见更好的文章

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值