RuntimeError: expected dtype Double but got dtype Float 问题解决

利用Pytorch框架自己构建网络结构,在程序运行到“loss.backward()”的时候报错:
 

RuntimeError: expected dtype Double but got dtype Float (validate_dtype at ..\aten\src\ATen\native\TensorIterator.cpp:143)
(no backtrace available)

通过查询资料得知,该错误来自于输入数据的类型和模型参数类型不一致。因此最好在程序开始统一数据类型。

Pytorch里的tensor创建时默认是Torch.FloatTensor类型(torch.float32),

可通过在import语句后增加语句

torch.set_default_tensor_type(torch.DoubleTensor)

这样之后创建的变量类型都是Double类型(torch.float64)。

如果想要创建变量类型都是Float类型,在import后增加语句

torch.set_default_tensor_type(torch.FloatTensor)

后,执行卷积操作时会报错:

RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #3 'mat1' in call to _th_addmm_

这是因为numpy的默认数据类型为float64,如果根据torch.from_numpy创建tensor,如b = torch.from_numpy(a),a和返回的b共享一块内存࿰

  • 4
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值