前提,cuda 11.6,pytorch 1.9.0,python 3.8
torch 没有autocast 这个属性,1.6.0版本以上的pytorch才有,我的版本是1.9.0,因为源代码说的是1.8.0以上都行,我自己试过很多个版本了,也升级过最新的,最新的又显示,我的cuda driver too old ,搞一下午没整出来
解决方法如下:
把
with torch.autocast(device.type,enabled=use_fp16):
改成:
with torch.cuda.amp.autocast(enabled=use_fp16):
我是偶然才发现这个问题的,无语了、、、、、
注意两句话参数不一样,传参的时候别搞错了