pytorch中nn.CrossEntropyLoss使用注意事项

Loss的数学表达公式:

使用代码样例:

# 这样展开就相当于每个词正确的类别和预测的整个词表概率分布进行对应
# ignore_index是指忽略真实标签中的类别
criterion = nn.CrossEntropyLoss(ignore_index=2).to(device) 
vocab_size = pre.shape[-1]
trg = trg[:,1:]
trg_tag = trg.reshape(-1).to(device) # view函数要求在同一个连续地址里,而reshape不用
pre_tag = pre[1:].view(-1,vocab_size).to(device)
loss = criterion(pre_tag,trg_tag)

注意事项:

  1. CrossEntropyLoss实例化之后,其两个输入分别是预测标签和真实标签,顺序不要搞错。预测标签的大小为[N,classnum],真实样本的大小为[N],因为该函数会把真实标签进行one-hot表示。N不一定是batchsize大小,可以对向量进行展开,从而可以逐个样本进行计算loss。

  2. 从公式可以看出,pytorch中的交叉熵loss其本身已使用的一个softmax约束了预测标签输入控制在了0-1之间,所以loss的输入即用模型的输出即可不需要通过softmax后再输入loss中,否则两个softmax可能会导致模型在训练的过程中loss保持不变。

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值