最近开始学习pytorch,在训练时出现标题所示问题。浏览了很多方法后,总结出出现这个问题的主要原因是输入的数据类型与网络参数的类型不符。
Input type为torch.cuda.FloatTensor(GPU数据类型), weight type(即net.parameters)为torch.FloatTensor(CPU数据类型)
网上资料大多数的解决方法是 将网络放到GPU上。
but 没有用。
我在论文作者的代码上面直接加了点东西。
就是说,最终的解决方法。
是将你加进去的模块,在上面的init中,self.模块=啥的。
然后在后面的forword里面再用,最后解决了的。