【Hugging Face】transformers 库中的 DataCollatorWithPadding: 对不同长度的输入文本进行动态填充(padding)

Hugging Face transformers 库中的 DataCollatorWithPadding

DataCollatorWithPaddingtransformers 库提供的一个 数据整理器(Data Collator),用于 对不同长度的输入文本进行动态填充(padding),使其可以批量(batch)输入到 Transformer 模型中进行训练或推理。


1. 为什么需要 DataCollatorWithPadding

在 NLP 任务中,模型通常需要 固定长度的输入,但不同文本的长度往往不同,例如:

["Hello!", "How are you?", "This is a longer sentence."]

DataLoader 处理中,必须 确保 batch 里的样本具有相同长度,常见方法有:

  • 固定长度填充(padding):所有样本填充到 max_length
  • 动态填充(dynamic padding):仅填充到 batch 内 最长的文本,减少计算浪费

Hugging Face 提供 DataCollatorWithPadding 进行动态填充,避免手动处理 max_length


2. DataCollatorWithPadding 的基本用法

from transformers import AutoTokenizer, DataCollatorWithPadding

# 加载 BERT Tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# 创建 DataCollator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

3. 使用 DataCollatorWithPadding 处理文本

假设我们有一个批量样本:

batch_sentences = ["Hello!", "How are you?", "This is a longer sentence."]

使用 tokenizer 进行分词:

tokenized_inputs = [tokenizer(text) for text in batch_sentences]
print(tokenized_inputs)

输出(不同长度的 token 序列):

[
 {'input_ids': [101, 7592, 999, 102], 'attention_mask': [1, 1, 1, 1]},
 {'input_ids': [101, 2129, 2024, 2017, 102], 'attention_mask': [1, 1, 1, 1, 1]},
 {'input_ids': [101, 2023, 2003, 1037, 2936, 6251, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
]

可以看到,每个句子的 input_ids 长度不同,无法直接组成 batch。


4. 使用 DataCollatorWithPadding 进行动态填充

import torch

# 使用 DataCollator 进行填充
batch = data_collator(tokenized_inputs)

# 转换为 PyTorch Tensor
batch = {k: torch.tensor(v) for k, v in batch.items()}

print(batch)

输出(所有样本填充到 batch 内最大长度 8):

{
 'input_ids': tensor([[ 101,  7592,  999,  102,    0,    0,    0,    0],
                      [ 101, 2129, 2024, 2017,  102,    0,    0,    0],
                      [ 101, 2023, 2003, 1037, 2936, 6251, 1012,  102]]),
 'attention_mask': tensor([[1, 1, 1, 1, 0, 0, 0, 0],
                           [1, 1, 1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 1, 1, 1, 1, 1]])
}

填充特点

  • input_ids 末尾填充 0
  • attention_mask 位置填充 0(表示 padding)

5. DataCollatorWithPaddingDataLoader 结合

在 PyTorch DataLoader 训练时,DataCollatorWithPadding 可自动填充 batch 内的样本

from torch.utils.data import DataLoader
from datasets import load_dataset

# 加载 IMDb 数据集
dataset = load_dataset("imdb", split="train")

# 预处理函数
def preprocess(example):
    return tokenizer(example["text"], truncation=True)

# 应用 `map` 进行批量处理
encoded_dataset = dataset.map(preprocess, batched=True)

# 创建 DataLoader
train_dataloader = DataLoader(encoded_dataset, batch_size=8, collate_fn=data_collator)

# 取一个 batch
batch = next(iter(train_dataloader))
print(batch)

特点

  • DataLoader 通过 collate_fn=data_collator 进行 自动填充
  • 无需手动指定 max_length
  • 支持变长批量处理,节省计算资源

6. DataCollatorWithPadding 的常见参数

data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding=True,             # 是否填充
    max_length=None,          # 可选:最大长度(默认 None,使用 batch 内最长文本)
    pad_to_multiple_of=8,     # 可选:填充到 8 的倍数(GPU 计算更快)
    return_tensors="pt"       # 返回 PyTorch Tensor
)

常见参数解析

参数作用默认值
tokenizer指定 Tokenizer必须
padding是否进行填充True
max_length限制最大序列长度None
pad_to_multiple_of是否填充到某个倍数(如 8,加速 GPU 计算)None
return_tensors指定返回数据类型("pt" / "tf" / "np"None

7. DataCollatorWithPadding vs tokenizer(..., padding=True)

方法适用场景是否动态填充训练/推理
tokenizer(..., padding=True)适用于 静态填充,需要提前 max_length❌ 固定长度✅ 适用于推理
DataCollatorWithPadding动态填充,根据 batch 长度调整✅ 适用于训练

示例对比

  • 使用 tokenizer 进行固定填充
batch = tokenizer(batch_sentences, padding=True, max_length=10, return_tensors="pt")
  • 使用 DataCollatorWithPadding 进行动态填充
batch = data_collator([tokenizer(text) for text in batch_sentences])

区别

  • tokenizer(..., padding=True) 填充到固定 max_length
  • DataCollatorWithPadding 仅填充到 batch 内最长文本

8. DataCollatorWithPadding 适用于哪些任务?

任务适用
文本分类
命名实体识别(NER)
机器翻译
摘要生成
文本生成(GPT-2, T5)❌ 不适用

对于 文本生成(GPT-2, T5),推荐 DataCollatorForSeq2SeqDataCollatorForLanguageModeling


9. 总结

  1. DataCollatorWithPadding 是 Hugging Face transformers 提供的批量填充工具,适用于 动态填充 Transformer 输入数据
  2. 动态填充 batch 内最长文本,比 max_length 固定填充更高效
  3. 适用于 PyTorch DataLoader,避免手动处理 max_length
  4. 支持 pad_to_multiple_of=8 加速 GPU 计算
  5. 适用于分类、NER、翻译等任务,但不适用于 GPT-2/T5 语言模型任务
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值