nlp中常用DataLoader中的collate_fn,对batch进行整理使其符合bert的输入

train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=default_collate)
此处的collate_fn,是一个函数,会将DataLoader生成的batch进行一次预处理

假设我们有一个Dataset,有input_ids、attention_mask等列:
使用torch创建dataloder时,如果使用默认的collate_fn(default_collate),输出的batch中,input_ids,和token_type_ids,attention_mask都是长度为 sequence_length 的列表(如果input_ids都已经pad到sequence_length),
列表的每个元素是大小为[batch_size]的tensor,
如果input_ids的长度不相等还会报错。

这并不是我们想要的模型的输入格式。
我们希望一个batch中,input_ids的应该是shape=(batch_size,max_seq_length)的tensor

默认情况下:

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=default_collate)
for batch in train_loader:
    print(batch)
    break
#如下所示,batch_size=3,所以input_ids都是长度为3的tensor组成的列表,这不是我们想要的
{'input_ids': List[torch.Tensor], 'attention_mask': List[torch.Tensor]}
Tensor的shape为(batch_size)(3)

这样是无法直接把这个batch输入bert的,必须把input_ids,和token_type_ids,attention_mask都转化为大小为 [ batch_size,sequence_length ] 的tensor才能输入bert。
所以,需要定义自己的collate_fn函数,对batch进行整理(使用torch.stack函数对列表进行拼接):

#定义collate_fn,把input_ids,和token_type_ids,attention_mask列转化为tensor
def collate_fn(examples):
    batch = default_collate(examples)
    batch['input_ids'] = torch.stack(batch['input_ids'], dim=1)
    batch['token_type_ids'] = torch.stack(batch['token_type_ids'], dim=1)
    batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=1)
    return batch

train_loader = DataLoader(dataset, batch_size=3, shuffle=True, collate_fn=collate_fn)
#输出第一个batch
for batch in train_loader:
    print(batch)
    break

结果如下,可以直接输入bert模型

output=model(**batch)
batch
{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor}
tensor的shape=(batch_size,max_seq_length),符合bert的输入格式

然而,以上方法必须要求所有input_ids长度一致,否则依然会报错。
其实tokenizer类已经提供了更简单的方法 tokenizer.pad()

    def collate_fn_SentenceClassify(features,tokenizer):
        #batch = default_collate(examples) 不再需要这句话
        #将batch中的input_ids和attention_mask等补齐
        batch = tokenizer.pad(
            features,
            padding=True,
            max_length=None,
        )
        return batch

此处,features是List[Dict[str,Tensor]], batch被整理成了Dict[str,Tensor],符合输入格式。
和DataCollatorWithPadding 效果一样

from transformers import DataCollatorWithPadding 

对于更复杂的任务,例如文本生成任务,可以直接调用transformers库中的collate_fn,例如

from transformers import DataCollatorForSeq2Seq

具体用法可以看huggingface官网
更多DataCollator

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值