本篇文章仅是针对自己最近项目中遇到的问题进行梳理沉淀,并不能复现解决相同的问题。但是或许可以提供一个解决问题的思路:尝试分析原始数据是否符合模型要求
背景
基于presumm和hiersumm两个文本摘要项目,希望使用bert encoder对数据进行编码。
任务详情
- 输入:一串文本数据
- 输入格式:二进制.pt文件
- 模型:bert,Transformer
- 输出:一串文本数据
问题起因
模型在加载数据时报错,报错内容如下
TypeError: forward() missing 1 required positional argument: 'attention_mask'
具体操作步骤
- 将文本数据(src,tgt,src_txt,tgt_txt)转换为.pt文件
- 将模型encoder设置为bert encoder
- 模型加载.pt文件
- 出现上述报错
分析
遇到该错误就在找到底attention_mask
在哪,查看bert源码发现是模型中forward()
函数,因为forward()函数是模型在迭代(如梯度下降、反向传播)才会调用,并不知道如何将attention_mask
怎么传入。因此也尝试发帖问了一些问题(问题描述的不够清楚,因为自己本身也没意识到问题根源在哪里,报错信息只是个表象)
CSDN问题链接
知乎提问链接