关于pytorch中关于nn.module的type问题。

RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same.

如题所示的错误为input的type时double的而weight的type时float的(weight的type在pytorch中默认为float),因此遇到这种问题一般有两种解决方案

1.将input的type由double更换为float
示例:

input = Variable(data)
input = input.float()

此种方案是将input的double更改为float毫无疑问会对矩阵内部数值产生影响。因此若是你的矩阵数值大到必须使用double那这个方案需要pass。若是你的矩阵数值使用float对矩阵本身没什么影响则可使用这条命令。

2.将nn.module的type由float更换为double
示例:

linear = nn.Linear(2, 2)
print(linear.weight)

linear.double()
print(linear.weight)
Parameter containing:
tensor([[-0.6322, -0.2967],
        [-0.5611,  0.0638]], requires_grad=True)
Parameter containing:
tensor([[-0.6322, -0.2967],
        [-0.5611,  0.0638]], dtype=torch.float64, requires_grad=True)

conv = nn.Conv2d(2, 64, 3, padding=1).double()
print(conv.weight)
Parameter containing:
tensor([[[[-0.2335, -0.0215,  0.0710],
          [-0.0369,  0.2331, -0.1877],
          [-0.1738,  0.1258, -0.0202]],

         [[-0.2055,  0.0458, -0.1719],
          [ 0.1927,  0.1866, -0.0046],
          [ 0.1286,  0.1005, -0.0137]]],


        [[[ 0.0427,  0.1982,  0.0761],
          [-0.2082, -0.1114,  0.0637],
          [ 0.1352,  0.0848,  0.1489]],

         [[ 0.0228,  0.2221, -0.1518],
          [-0.0462,  0.1292, -0.1872],
          [-0.0874,  0.0772,  0.1837]]],


        [[[ 0.2184,  0.2230, -0.0358],
          [-0.2187,  0.1524,  0.0409],
          [ 0.2244,  0.1070, -0.0077]],

         [[ 0.1362, -0.1353, -0.0493],
          [-0.0381, -0.0550, -0.0046],
          [ 0.1634,  0.1359, -0.1857]]],


        ...,


        [[[-0.2268, -0.1568,  0.0236],
          [ 0.0697, -0.0852,  0.0374],
          [-0.0336, -0.0743, -0.0871]],

         [[ 0.0218, -0.0523,  0.1762],
          [-0.2168, -0.1741,  0.1094],
          [ 0.0216,  0.1672, -0.1006]]],


        [[[-0.1174,  0.1897, -0.1503],
          [-0.1382, -0.1979,  0.1617],
          [ 0.2282,  0.2337,  0.1170]],

         [[-0.1433, -0.2325,  0.1023],
          [ 0.2197, -0.1562, -0.1585],
          [-0.0422,  0.0011,  0.0161]]],


        [[[ 0.1030,  0.1959, -0.1112],
          [-0.0733, -0.0093,  0.1591],
          [ 0.1522,  0.0894,  0.1800]],

         [[ 0.0505,  0.1076, -0.0392],
          [ 0.1726,  0.0952, -0.1742],
          [-0.2179,  0.2116,  0.0131]]]], dtype=torch.float64,
       requires_grad=True)

如代码所示,将你定义conv层或者maxpooling、liner层加后缀.double()。则自动强制转换浮点参数和缓冲区为double数据类型。此方案适用于input必须为double类型数据。但麻烦点在于需要为定义的每一个隐藏层加.double()。

参考链接:
cnblogs.com/wanghui-garcia/p/11285055.html
Pytorch

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值