最近在学习自然语言处理,到张量运算了,记录出错问题:
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Long
源代码
import torch
x=torch.tensor([[1,2,3],[4,5,6]])
x.mean()
查找修改后
import torch
x=torch.tensor([[1.0,2,3],[4,5,6]])
x.mean()
或者
import torch
x=torch.Tensor([[1,2,3],[4,5,6]])
x.mean()
原因:
input:输入张量。它的数据类型必须是浮点型或复数型。对于复数的输入,范数使用每个元素的绝对值。注意,输入张量中元素的数据类型一定得是浮点型或者是复数。这就是报错原因。