下面是我定义的一个函数
def run(x):
torch.set_default_tensor_type(‘torch.DoubleTensor’)
x=np.array(x)
a=torch.tensor(2).type(torch.DoubleTensor)
x=x.reshape(2,200)
x=torch.DoubleTensor(np.array(x))
lstm = torch.load(‘lstm.pkl’)
x=torch.unsqueeze(x, dim=2)
output=lstm(x)
运行时在lstm(x)这里出错:RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat2' in call to _th_mm
这时候我们需要把x的数据类型进行修改
只需要加上一行 x = torch.tensor(x, dtype=torch.float32)就可以了。