使用torch_geometric库报错RuntimeError: expected scalar type Float but found Double

参考:(2条消息) RuntimeError: expected scalar type Double but found Float_edward_zcl的博客-CSDN博客https://blog.csdn.net/edward_zcl/article/details/124492199

报错如下:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-57-0b742325a6a8> in <module>
      2     print(data.edge_index.dtype)
      3     print(data.x.dtype)
----> 4     x = ex(data.x, data.edge_index, data.batch)
      5     break

D:\Download\Anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

<ipython-input-52-14eff65c6186> in forward(self, x, edge_index, batch)
     52             if i != self.num_gc_layers - 1:
     53 #                 x = torch.tensor(x, dtype=torch.float32)
---> 54                 print(self.convs[i](x, edge_index))
     55                 x = F.relu(self.bns[i](self.convs[i](x, edge_index)))
     56             else:

D:\Download\Anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

D:\Download\Anaconda\envs\pytorch\lib\site-packages\torch_geometric\nn\conv\gin_conv.py in forward(self, x, edge_index, size)
     78             out += (1 + self.eps) * x_r
     79 
---> 80         return self.nn(out)
     81 
     82     def message(self, x_j: Tensor) -> Tensor:

D:\Download\Anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

D:\Download\Anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
    117     def forward(self, input):
    118         for module in self:
--> 119             input = module(input)
    120         return input
    121 

D:\Download\Anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

D:\Download\Anaconda\envs\pytorch\lib\site-packages\torch\nn\modules\linear.py in forward(self, input)
     92 
     93     def forward(self, input: Tensor) -> Tensor:
---> 94         return F.linear(input, self.weight, self.bias)
     95 
     96     def extra_repr(self) -> str:

D:\Download\Anaconda\envs\pytorch\lib\site-packages\torch\nn\functional.py in linear(input, weight, bias)
   1751     if has_torch_function_variadic(input, weight):
   1752         return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1753     return torch._C._nn.linear(input, weight, bias)
   1754 
   1755 

RuntimeError: expected scalar type Float but found Double

报错为期待的张量类型是double但是输入的是float,可以将模型所有的层的输入输出类型打印出来

for name, param in model.named_parameters():
    print(name,'-->',param.type(),'-->',param.dtype,'-->',param.shape)

打印输入的数据格式,我使用pyG的Data存储图数据

print(data.edge_index.dtype)
print(data.x.dtype)

打印出来后发现两者数据类型确实不匹配。需要修改data.x和data.edge_index的数据类型以适配。

修改tensor的数据类型示例:

import torch

t = torch.tensor([1,2])
# sample 1
t = torch.tensor(t, dtype=tensor.int64)
# sample 2
t = torch.tensor(t, dtype=tensor.float32)

* pyG要求data.edge_index为int64或long,我一开始用的是double也报错了。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值