Pytorch里的CrossEntropyLoss详解

转自: Pytorch里的CrossEntropyLoss详解_aiwanghuan5017的博客-CSDN博客

首先要知道上面提到的这些函数一部分是来自于torch.nn,而另一部分则来自于torch.nn.functional(常缩写为F)。二者函数的区别可参见 知乎:torch.nn和funtional函数区别是什么?

下面是对与cross entropy有关的函数做的总结:

torch.nntorch.nn.functional (F)
CrossEntropyLosscross_entropy
LogSoftmaxlog_softmax
NLLLossnll_loss

下面将主要介绍torch.nn.functional中的函数为主,torch.nn中对应的函数其实就是对F里的函数进行包装以便管理变量等操作。

在介绍cross_entropy之前先介绍两个基本函数:


log_softmax

这个很好理解,其实就是logsoftmax合并在一起执行。


nll_loss

该函数的全程是negative log likelihood loss,函数表达式为

f(x,class)=−x[class]f(x,class)=−x[class]

例如假设x=[1,2,3],class=2x=[1,2,3],class=2,那额f(x,class)=−x[2]=−3f(x,class)=−x[2]=−3


cross_entropy

交叉熵的计算公式为:

cross_entropy=−∑k=1N(pk∗logqk)cross_entropy=−∑k=1N(pk∗log⁡qk)

其中pp表示真实值,在这个公式中是one-hot形式;qq是预测值,在这里假设已经是经过softmax后的结果了。

仔细观察可以知道,因为pp的元素不是0就是1,而且又是乘法,所以很自然地我们如果知道1所对应的index,那么就不用做其他无意义的运算了。所以在pytorch代码中target不是以one-hot形式表示的,而是直接用scalar表示。所以交叉熵的公式(m表示真实类别)可变形为:

cross_entropy=−∑k=1N(pk∗logqk)=−logqmcross_entropy=−∑k=1N(pk∗log⁡qk)=−logqm

仔细看看,是不是就是等同于log_softmaxnll_loss两个步骤。

所以Pytorch中的F.cross_entropy会自动调用上面介绍的log_softmaxnll_loss来计算交叉熵,其计算方式如下:

loss(x,class)=−log(exp(x[class])∑jexp(x[j]))loss⁡(x,class)=−log⁡(exp⁡(x[class])∑jexp⁡(x[j]))

代码示例

 
  1. >>> input = torch.randn(3, 5, requires_grad=True)

  2. >>> target = torch.randint(5, (3,), dtype=torch.int64)

  3. >>> loss = F.cross_entropy(input, target)

  4. >>> loss.backward()

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值