1.问题
在运行PyTorch模型脚本训练分类网络时,出现如下报错:
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'
出错代码为:
loss_function = torch.nn.CrossEntropyLoss()
loss = loss_function (out, label)
备注:CrossEntropyLoss的源码里“label”参数使用的dtype是“long”类型。
2.解决办法
修改的时候,直接在label后面加上.long(),如下所示:
loss_function = torch.nn.CrossEntropyLoss()
loss = loss_function (out, label.long())