Pytorch CrossEntropyloss使用方法(包含多维target)
以前都是用tf,最近转来用pytorch。最近博主做一个东西需要用到crossentropyloss,输出是多维输出的。一开始胡乱弄搞出了一个这样子的bug:
RuntimeError: multi-target not supported at在这里插入代码片
然后博主寻求百度,结果发现网上大部分人都只是在照搬例程水个流量,并没有想要的答案。最终还是得科学上网,远赴官方文档找到了使用方法。这里记录一下,也方便以后可能要用的小伙伴。
官方文档给出的用法如下:
也就是说,在网络的output要把分类放在第二维,第二维后面的代表的是网络的维度,看起来非常简单,博主的示例代码如下:
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, 2, requires_grad = True)
# target = torch.empty(3, dtype = torch.long).random_(5)
target = torch.empty(3, 2, dtype = torch.long).random_(5)
output = loss(input, target)