原因在于自己整理数据集的时候,使用了np.array,然后默认保存成float64,但是pytorch中默认是float32
首先找到代码出错的位置,将该处的数据类型转为float类型:
x = x.type(torch.FloatTensor)
原因在于自己整理数据集的时候,使用了np.array,然后默认保存成float64,但是pytorch中默认是float32
首先找到代码出错的位置,将该处的数据类型转为float类型:
x = x.type(torch.FloatTensor)