Pytorch CrossEntropyloss使用方法(包含多维target)

Pytorch CrossEntropyloss使用方法(包含多维target)

以前都是用tf,最近转来用pytorch。最近博主做一个东西需要用到crossentropyloss,输出是多维输出的。一开始胡乱弄搞出了一个这样子的bug:

RuntimeError: multi-target not supported at在这里插入代码片

然后博主寻求百度,结果发现网上大部分人都只是在照搬例程水个流量,并没有想要的答案。最终还是得科学上网,远赴官方文档找到了使用方法。这里记录一下,也方便以后可能要用的小伙伴。

官方文档给出的用法如下:
在这里插入图片描述
也就是说,在网络的output要把分类放在第二维,第二维后面的代表的是网络的维度,看起来非常简单,博主的示例代码如下:


loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, 2, requires_grad = True)
# target = torch.empty(3, dtype = torch.long).random_(5)
target = torch.empty(3, 2, dtype = torch.long).random_(5)
output = loss(input, target)
  • 9
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值