本人在跑Bert+CRF代码时候遇到此问题:发生异常:RuntimeError CUDA error: device-side assert triggered
原因分析:
在debug过程中,代码在cuda无法查看具体出错的原因,这个时候我们需要将代码更改为在cpu运行(经过高人指点:cuda发生错误,都可以在cpu版本找到具体的错误原因)
解决方法
更改为cpu运行后,发现是因为标签越界了,更改正确即可
标签越界原因分析
我在文本补全时,将pad编号设置为-100(因为想要pad在交叉熵计算是失效)这个编号导致了错误发生。
经过实验,我发现通过pytroch-crf库调用的crf无法处理编号为负数的标签。事实上我们也无需设置特定的标签令它失效,因为只要将正确的attention_mask传进crf,它就能自动忽略pad的token