问题描述: 在使用Pytorch的过程中,有时会遇到以下报错:
RuntimeError: Expected object of type torch.cuda.DoubleTensor but found type torch.cuda.FloatTensor for argument #3 'other'
这种错误是由于数据类型不匹配造成的。这种不匹配可能来自Pytorch各个层之间,也可能来自于使用Dataset和Dataloader导入来自Numpy的数据。后者有时更难以发现。
原因分析: 如果Pytorch的数据来源是Numpy,要十分注意在Numpy中,小数的默认数据类型是np.float
,但np.float
与np.float64
等价;在Pytorch中,默认数据类型是torch.float
,但float
与t