今天在使用pytorch1.7训练模型,调用apex时出现了该问题:
Doubletensor是64位float类型,Floattensor是32位float类型,调试程序发现第二次迭代时,变量v和msd[k]的数据类型不同,如图所示:
因此在条件循环后添加程序:
if v.dtype != msd[k].dtype:
v = v.to(msd[k].dtype)
修改后如图:
解决问题!