理解Pytorch中的collate_fn函数

PyTorch中的DataLoader是最常用的类之一,这个类有很多参数(14 个),但大多数情况下,你可能只会使用其中的三个:dataset、shuffle 和 batch_size。其中collate_fn是比较少用的函数,这对初学者来说是一个容易混淆的概念。下面将简要探讨 PyTorch 如何创建批次,并了解如何根据我们的需求修改其默认行为。

批处理

首先,先创建一个数据,用data变量标视

import torch
from torch.utils.data import DataLoader
import numpy as np

data = np.array([
    [0.1, 7.4, 0],
    [-0.2, 5.3, 0],
    [0.2, 8.2, 1],
    [0.2, 7.7, 1]])
print(data)

如果我们加载一个批次出来( shuffle=False以消除随机性):

loader = DataLoader(data, batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch)

# tensor([[ 0.1000,  7.4000,  0.0000],
#         [-0.2000,  5.3000,  0.0000]], dtype=torch.float64)

结果符合预期的,我们来解释一下已经做了什么:

  • 加载器从数据集中选择了 2 个样本。
  • 这些样本被转换为张量(2 个大小为 3 的样本)。
  • 创建并返回一个新的张量 (2x3)。

默认设置还允许我们使用字典。 让我们看一个例子:

from pprint import pprint
# now dataset is a list of dicts
dict_data = [
    {'x1': 0.1, 'x2': 7.4, 'y': 0},
    {'x1': -0.2, 'x2': 5.3, 'y': 0},
    {'x1': 0.2, 'x2': 8.2, 'y': 1},
    {'x1': 0.2, 'x2': 7.7, 'y': 10},
]
pprint(dict_data)
# [{'x1': 0.1, 'x2': 7.4, 'y': 0},
# {'x1': -0.2, 'x2': 5.3, 'y': 0},
# {'x1': 0.2, 'x2': 8.2, 'y': 1},
# {'x1': 0.2, 'x2': 7.7, 'y': 10}]

loader = DataLoader(dict_data, batch_size=2, shuffle=False)
batch = next(iter(loader))
pprint(batch)
# {'x1': tensor([ 0.1000, -0.2000], dtype=torch.float64),
#  'x2': tensor([7.4000, 5.3000], dtype=torch.float64),
#  'y': tensor([0, 0])}

Dataloader简单易用,可以正确地从字典列表中重新打包数据。 当你的数据采用 JSON格式时,此功能非常方便。

自定义collate函数

Dataloader默认设置能覆盖大部分场景的数据读取,但默认设置有一个很大的限制——批数据必须处于同一维度。 假设我们有一个 NLP 任务,并且数据是分词后的文本。

# values are token indices but it does not matter - it can be any kind of variable-size data
nlp_data = [
    {'tokenized_input': [1, 4, 5, 9, 3, 2],
     'label':0},
    {'tokenized_input': [1, 7, 3, 14, 48, 7, 23, 154, 2],
     'label':0},
    {'tokenized_input': [1, 30, 67, 117, 21, 15, 2],
     'label':1},
    {'tokenized_input': [1, 17, 2],
     'label':0},
]
loader = DataLoader(nlp_data, batch_size=2, shuffle=False)
batch = next(iter(loader))

这样强行去压成一个batch存储,会引发错误:

/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
     80         elem_size = len(next(it))
     81         if not all(len(elem) == elem_size for elem in it):
---> 82             raise RuntimeError('each element in list of batch should be of equal size')
     83         transposed = zip(*batch)
     84         return [default_collate(samples) for samples in transposed]

RuntimeError: each element in list of batch should be of equal size

报错信息显示不能创建非矩形张量。顺便说一句,可以看到触发错误的是 default_collate函数。

如何修改? 有两种解决方案:

  • 将整个数据集填充到最长的样本。
  • 在Batch data创建期间进行动态填充。

第一个方法最简单,但是非常耗内存,极端条件下,我们有1万条数据,其中9999条数据长度是10,而只有1条数据长度是1000,那么所有的数据都需要pad数值,使得长度填充到1000。这样有99%的内存占用都是无意义的。

pad-max

另一种方法是动态填充数据。 当选择这一个batch的样本时,我们只需要将数据填充到这一个batch最长的样本长度即可。另外,将数据按照长度进行排序,则填充的数量将是最小的。 如果有一些非常长的序列,它们只会影响它们的这一个batch的效率,而不是整个数据集。

pad-batch

具体如何实现呢?

from torch.nn.utils.rnn import pad_sequence #(1)

def custom_collate(data): #(2)
    inputs = [torch.tensor(d['tokenized_input']) for d in data] #(3)
    labels = [d['label'] for d in data]

    inputs = pad_sequence(inputs, batch_first=True) #(4)
    labels = torch.tensor(labels) #(5)

    return { #(6)
        'tokenized_input': inputs,
        'label': labels
    }

loader = DataLoader(
  	nlp_data, 
    batch_size=2, 
    shuffle=False, 
    collate_fn=custom_collate
) #(7)

iter_loader = iter(loader)
batch1 = next(iter_loader)
pprint(batch1)
batch2 = next(iter_loader)
pprint(batch2)

# {'label': tensor([0, 0]),
#  'tokenized_input': tensor([
#   [  1,   4,   5,   9,   3,   2,   0,   0,   0],
#   [  1,   7,   3,  14,  48,   7,  23, 154,   2]
# ])}

# {'label': tensor([1, 0]),
#  'tokenized_input': tensor([
#   [  1,  30,  67, 117,  21,  15,   2],
#   [  1,  17,   2,   0,   0,   0,   0]])}

代码功能如下:

  • 我们使用 pad_sequence进行填充
  • custom_collate作为参数传递给DataLoader
  • 在运行时对inputs进行动态填充

总结

collate_fn是一个很少用的函数,但对提升训练效率有很大的帮助。

  • 5
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值