比如报错为:RuntimeError: expected scalar type Long but found Float
,就是希望输入为torch.Long
,结果得到一个torch.Float
解决方法
转换类型即可,更多方法请参考:https://blog.csdn.net/weixin_35757704/article/details/118378709
import torch
x = torch.FloatTensor([1, 2, 3, 4, 5, 6, 8])
x = x.type(torch.LongTensor)
print(x.type()) # torch.LongTensor