PyTorch expects the input to a layer to have the same device and data type (dtype) as the layer's parameters. For most layers, including conv and LSTM layers, the parameters default to torch.float32, which is the global default returned by torch.get_default_dtype().
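A quick way to check this: inspect the parameter dtypes of a couple of layers and compare them with what torch.arange returns for integer arguments. This sketch assumes a default PyTorch setup where torch.set_default_dtype() has not been changed. The layer sizes are arbitrary.

```python
import torch
from torch import nn

# Floating-point parameters are created with the global default dtype,
# which is torch.float32 unless torch.set_default_dtype() was called (assumed not).
print(torch.get_default_dtype())  # torch.float32

conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
lstm = nn.LSTM(input_size=10, hidden_size=3)

# Collect the set of dtypes used by each layer's parameters.
conv_dtypes = {p.dtype for p in conv.parameters()}
lstm_dtypes = {p.dtype for p in lstm.parameters()}
print(conv_dtypes, lstm_dtypes)  # {torch.float32} {torch.float32}

# Integer arguments make torch.arange return int64, which does not match.
x = torch.arange(1, 11)
print(x.dtype)  # torch.int64
```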
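import torch
from torch import nn

# Without dtype=torch.float32 this raises an error: torch.arange with integer arguments returns torch.int64
a = torch.arange(1, 401, dtype=torch.float32).view(8, 5, 10)  # (batch, seq_len, input_size)
net = nn.LSTM(input_size=10,
              hidden_size=3,
              batch_first=True,
              bidirectional=True)
res, _ = net(a)
print(res)  # shape (8, 5, 6): the hidden states of both directions are concatenated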
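So whenever you create an input tensor for a network, remember to give it the same dtype as the network's parameters, normally torch.float32 (either pass dtype=torch.float32 when creating it or convert an existing tensor with .float()).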
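If a tensor already exists with the wrong dtype, it can be converted instead of recreated. The sketch below shows two options for the LSTM above: an explicit .float(), and a hypothetical helper cast_like_params (not part of PyTorch) that copies the dtype and device from the module's first parameter. The helper assumes all of the module's parameters share one dtype and device.

```python
import torch
from torch import nn


def cast_like_params(x: torch.Tensor, module: nn.Module) -> torch.Tensor:
    """Hypothetical helper: return x on the same device and with the same
    dtype as the first parameter of `module`."""
    param = next(module.parameters())
    return x.to(device=param.device, dtype=param.dtype)


net = nn.LSTM(input_size=10, hidden_size=3, batch_first=True, bidirectional=True)

# Created without a dtype, so this tensor is int64.
a = torch.arange(1, 401).view(8, 5, 10)

# Option 1: convert explicitly; .float() is shorthand for .to(torch.float32).
res1, _ = net(a.float())

# Option 2: match whatever dtype/device the layer's parameters use.
res2, _ = net(cast_like_params(a, net))

print(res1.shape)  # torch.Size([8, 5, 6]): batch, seq_len, 2 * hidden_size
print(torch.allclose(res1, res2))  # True
```

Option 2 is handy when the network may later be moved to a GPU or switched to another dtype, since the input then follows the parameters automatically.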