【torch报错处理】RuntimeError: probability tensor contains either inf
, nan
or element < 0
速通版
本次报错解决方案是torch版本降级:2.4cu118 -> 2.1cu118
写于此随意记录,现在还是不明白原因,求大佬指点 2024/10/03
问题分析
RuntimeError: probability tensor contains either inf
, nan
or element < 0
这个报错网络上有很多朋友遇到,也有不少解决方案被提出。该报错的原因主要是因为计算出来的probability太大了,超过了当前数据类型能表示的最大值。网上搜到的一些解决方案有:
- 采用
from_pretrained
载入模型时,传入dtype
参数,用torch.bfloat16
替代torch.float16
。
分析:该方案合理,bfloat16比float16表示大数有优势。不过很可惜本次我遇到的时候已经使用了bfloat16 - 网上也有同学说需要把
tokenzier.pad_token='[PAD]' or tokenizer.eos_token
改成tokenzier.pad_token=tokenizer.unk_token
。还有的人说pad right不行,要改成pad left。
分析:有人说这样成功了,但是本人场景下实验不成功。因为本人实验的时候batchsize=1也根本没有pad的过程…
观察
本人实验时用了8个样本(8个prompt),这8个样本都会产生上述报错异常。
为了理解这个报错,尝试了几种方式试图找到原因:
- 尝试1: 更换模型,从llama3改成llama2
结论:有变化,4个样本没产生报错正常输出结果了 - 尝试2: 调整温度系数
结论:调高温系数,没观察到变化(可能调整到不够高) - 尝试3: 更换prompt
结论:将复杂的prompt更换成简单的prompt,比如"how are you",模型可以稳定的不报错的输出正确的内容 - 尝试4: 32bit!
结论:完全正常!
这说明,在复杂的prompt(比较少见的样本)下,模型很容易被精度卡脖子。但是我不可能一直用32bit啊!!而且我的任务就是很难。
成功的操作
正如前面所说,torch降级。2.4降到2.1就解决了。
是不是torch在实现softmax或者是什么地方他们的bfloat16精度范围不一样?为什么低版本比高版本在这个问题上还更加鲁棒?有没有大佬知道一些呢?我啥都不知道OMG
之前碰到过 torch.tril 函数不支持 bfloat16的报错(还有其他一堆…),通过torch 2.0升级到2.4解决了,这次2.4又降回到2.1。呃,留给我能用的版本号越来越少了…TAT