pytorch中CrossEntropyLoss的使用

CrossEntropyLoss 等价于 softmax+log+NLLLoss

LogSoftmax等价于softmax+log

 

可用于文本分类、序列标注等计算损失

使用方法:

# 首先定义该类
loss = torch.nn.CrossEntropyLoss()
#然后传参进去
loss(input, target)

input维度为N*C,是网络生成的值,N为batch_size,C为类别数;

target维度为N,是标注值,非one-hot类型的值;

input = torch.randn(4,3)
target = torch.tensor([0,1,1,2])   #必须为Long类型,是类别的序号
cross_entropy_loss = nn.CrossEntropyLoss()
loss = cross_entropy_loss(input, target)


# 对于序列标注来说,需要reshape一下
input = torch.randn(2,4,3)  # 2为batch_size, 4为seq_length,3为类别数
input = input.view(-1,3)    # 一共8个token
target = torch.tensor([[0,1,1,2], [2,3,1,0]])
target = target.view(-1)
loss = cross_entropy_loss(input, target)  # reduction='mean',默认为mean;
  • 6
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值