Runtime error: expected scalar type Float but found Double
使用pytorch 训练网络的时候,遇到了这个错误,我也是一脸蒙逼,后来把输入到网络里面的数据流给打印出来,发现我的输入是:
<class ‘torch.Tensor’> torch.Size([128, 4]) torch.float64
<class ‘torch.Tensor’> torch.Size([128]) torch.int64
明明是float64 , 怎么说是double 呢
后来发现 torch.nn.Linear 支持的时float32 ,于是把数据格式转换为float32试试
发现程序通过了
因此,这个错误提示的意思应该是你的数据类型不正确!