报错截图:
问题描述:
我之前是在一张3090ti上跑的代码,没有报错,pytorch版本是1.11,python是3.7,cuda版本是11.0。当我切换到两个卡(一个是3090ti,另一个是3060)运行时出现这个错误。感觉应该是torch版本和显卡兼容问题,pytorch版本高了。但是降级的话很麻烦。因此看看能不能修改下源码。
解决方案:
那么看报错的地方:
这句话:
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)
将next(self.parameters()).dtype修改成torch.float32:
extended_attention_mask = extended_attention_mask.to(dtype=torch.float32)
再次运行代码,解决问题!!