关于pytorch中关于nn.module的type问题。

RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same.

如题所示的错误为input的type时double的而weight的type时float的(weight的type在pytorch中默认为float),因此遇到这种问题一般有两种解决方案

1.将input的type由double更换为float
示例:

input = Variable(data)
input = input.float()

此种方案是将input的double更改为float毫无疑问会对矩阵内部数值产生影响。因此若是你的矩阵数值大到必须使用double那这个方案需要pass。若是你的矩阵数值使用float对矩阵本身没什么影响则可使用这条命令。

2.将nn.module的type由float更换为double
示例:

linear = nn.Linear(2, 2)
print(linear.weight)

linear.double()
print(linear.weight)
Parameter containing:
tensor([[-0.6322, -0.2967],
        [-0.5611,  0.0638]], requires_grad=True)
Parameter containing:
tensor([[-0.6322, -0.2967],
        [-0.5611,  0.0638]], dtype=torch.float64, requires_grad=True)

conv = nn.Conv2d(2, 64, 3, padding=1).double()
print(conv.weight)
Parameter containing:
tensor([[[[-0.2335, -0.0215,  0.0710],
          [-0.0369,  0.2331, -0.1877],
          [-0.1738,  0.1258, -0.0202]],

         [[-0.2055,  0.0458, -0.1719],
          [ 0.1927,  0.1866, -0.0046],
          [ 0.1286,  0.1005, -0.0137]]],


        [[[ 0.0427,  0.1982,  0.0761],
          [-0.2082, -0.1114,  0.0637],
          [ 0.1352,  0.0848,  0.1489]],

         [[ 0.0228,  0.2221, -0.1518],
          [-0.0462,  0.1292, -0.1872],
          [-0.0874,  0.0772,  0.1837]]],


        [[[ 0.2184,  0.2230, -0.0358],
          [-0.2187,  0.1524,  0.0409],
          [ 0.2244,  0.1070, -0.0077]],

         [[ 0.1362, -0.1353, -0.0493],
          [-0.0381, -0.0550, -0.0046],
          [ 0.1634,  0.1359, -0.1857]]],


        ...,


        [[[-0.2268, -0.1568,  0.0236],
          [ 0.0697, -0.0852,  0.0374],
          [-0.0336, -0.0743, -0.0871]],

         [[ 0.0218, -0.0523,  0.1762],
          [-0.2168, -0.1741,  0.1094],
          [ 0.0216,  0.1672, -0.1006]]],


        [[[-0.1174,  0.1897, -0.1503],
          [-0.1382, -0.1979,  0.1617],
          [ 0.2282,  0.2337,  0.1170]],

         [[-0.1433, -0.2325,  0.1023],
          [ 0.2197, -0.1562, -0.1585],
          [-0.0422,  0.0011,  0.0161]]],


        [[[ 0.1030,  0.1959, -0.1112],
          [-0.0733, -0.0093,  0.1591],
          [ 0.1522,  0.0894,  0.1800]],

         [[ 0.0505,  0.1076, -0.0392],
          [ 0.1726,  0.0952, -0.1742],
          [-0.2179,  0.2116,  0.0131]]]], dtype=torch.float64,
       requires_grad=True)

如代码所示,将你定义conv层或者maxpooling、liner层加后缀.double()。则自动强制转换浮点参数和缓冲区为double数据类型。此方案适用于input必须为double类型数据。但麻烦点在于需要为定义的每一个隐藏层加.double()。

参考链接:
cnblogs.com/wanghui-garcia/p/11285055.html
Pytorch

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
根据提供的引用内容,torch.nn.modules.module.Module类的__getattr__方法用于获取给定name的成员。它首先从self.__dict__['_parameters']、self.__dict__['_buffers']以及self.__dict__['_modules']查找,找到后返回该成员;如果找不到,则会引发AttributeError异常,报错信息为"'{}' object has no attribute '{}'".format(type(self).__name__, name)。 在这种情况下,错误信息提示是"torch.nn.modules.module.ModuleAttributeError: 'DataParallel' object has no attribute 'get_params'"。这表示DataParallel对象没有名为'get_params'的属性。这个错误通常发生在尝试访问一个不存在的属性时。可能是代码使用了错误的属性名或者没有正确设置属性。要解决这个错误,你可以检查代码的属性名是否正确拼写,并确保属性已正确设置。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [in <module>报错_【PyTorch】torch.nn.Module 源码分析](https://blog.csdn.net/weixin_39851809/article/details/110105784)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [undefined](undefined)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值