Huggingface transformer的Trainer中data_collator的使用

什么时候使用

Transformers Trainer的文档中可知,Trainer函数有一个参数data_collator,其值也为一个函数,用于从一个list of elements来构造一个batch。这个函数其实就是torch.utils.data.DataLoader中的collate_fn。如果还不知道collate_fn是做何用处,请参考这篇文档
要用到这个函数,要符合如下两个条件:

  1. Trainer的参数train_dataseteval_dataset是torch.utils.data.Dataset或torch.utils.data.IterableDataset的实体
  2. train_dataseteval_dataset(torch.utils.data.Dataset)加载入DataLoader后,得到的batch不可用,还不能直接加入到model的forward中计算

如何用

这里假设读者已经知道torch.utils.data.DataLoader的collate_fn用法,只介绍Trainer的data_collator和torch.utils.data.DataLoader的collate_fn的差异。
差异就是,输出格式!torch.utils.data.DataLoader的collate_fn的输出可以是各种格式,但Trainer的data_collator的输出只能是一个dict,这个dict的键必须包含“input_ids”,“attention_mask”等transformers models正常运算必要的参数的名称,如果需要,也可以加入任何transformers model.forward()可接受的参数名,而这些键对应的值也必须是transformers model中该键应该对应的输入值。
如果想让模型自动训练loss,还要在这个dict中加入如下键值对:{“labels”: labels in tensor type},这样模型的输出里就有loss了。

为什么呢?

看两段源码其实就差不多明白了:
在这里插入图片描述
在这里插入图片描述
第一张图中,这个DataLoader就是一个纯粹的torch.utils.data.DataLoader,self.data_collator就是输入的data_collator函数。所以,这个data_collator就彻彻底底是一个DataLoader的collate_fn啊
第二张图中,input就是如下迭代的结果(其中的dataloader就是第一张图中的dataloader)

for step, inputs in enumerate(DataLoader)

所以,inputs的键值对必须要与model.forwards()的参数相对应也是显然的

  • 15
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值