RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor
The code is as follows:
>>> import torch
>>> a=torch.Tensor([1])
>>> b=torch.Tensor([2])
>>> a.requires_grad
False
>>> a.requires_grad=True
>>> c=a+b
>>> c
tensor([3.], grad_fn=<ThAddBackward>)
>>> c.requires_grad
True
>>> b=b.type(torch.cuda.FloatTensor)
>>> b
tensor([2.], device='cuda:0')
>>> c=a+b
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #3 'other'
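Cause: `a` is a CPU tensor, while `b` has been converted to a CUDA tensor, and PyTorch will not add tensors that live on different devices.
Solution: move `b` back to the CPU so that both operands are on the same device: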
>>> b=b.cpu()
>>> c=a+b
>>> c
tensor([3.], grad_fn=<ThAddBackward>)
>>> c.requires_grad
True
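The `b.cpu()` fix works because both operands end up on the same device. A more general pattern is to move one operand onto the other's device explicitly. The sketch below assumes a recent PyTorch version (it uses `torch.tensor(..., device=...)` and `Tensor.to()` instead of the legacy `.type(torch.cuda.FloatTensor)` call), and the `device` variable is an illustrative choice that falls back to the CPU when no GPU is present, so it also runs on CPU-only machines.

```python
import torch

# Illustrative device choice: use the GPU if available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

a = torch.tensor([1.0], requires_grad=True)  # leaf tensor on the CPU
b = torch.tensor([2.0], device=device)       # may live on the GPU

# Move b onto a's device before combining them;
# .to() returns b itself when it is already on that device.
c = a + b.to(a.device)
print(c)

# Gradients still flow back to a, since only b (which does not
# require grad) was moved.
c.backward()
print(a.grad)  # tensor([1.])
```

Moving `a` to the GPU instead (`a.to(device)`) also works and is usually the better choice when the rest of the computation runs on the GPU. The moved copy is a non-leaf tensor, so gradients still accumulate in the original `a`.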