Pytorch 手工复现交叉熵损失(Cross Entropy Loss)

如果直接调包的话很简单,例子如下:

import torch
import torch.nn as nn

torch.manual_seed(1234)
ce_loss = nn.CrossEntropyLoss()
x_input = torch.randn(3,3)
y_target = torch.tensor([1,2,0])
print(ce_loss(x_input, y_target))

这里的x_input可以理解为我们的网络预测结果,而y_target为真值,此处例子的结果是0.8526。那么自己做的话,考虑交叉熵的公式如下: L = 1 N ∑ i L i = − 1 N ∑ i ∑ c = 1 M y i c log ⁡ ( p i c ) L=\frac{1}{N} \sum_{i} L_{i}=-\frac{1}{N} \sum_{i} \sum_{c=1}^{M} y_{i c} \log \left(p_{i c}\right) L=N1iLi=N1ic=1Myiclog(pic) 其中 M M M为总类别数, y i c y_{ic} yic为符号函数(只能为0或1), p i c p_{ic} pic为预测得到的概率(即此处的网络预测结果x_input)。

首先为了方便"矩阵运算",我们将y给展平为列向量的形式:

y_target = y_target.view(-1, 1)

对于x,将其做softmax压缩至0-1范围内后再进行log运算:

x_input = F.log_softmax(x_input, 1)

接着,利用pytorch的gather函数,找到各标签y所对应的x:

x_input.gather(1, y_target)

这一步是相对最难理解的。回到公式,考虑到我们已经算完了 log ⁡ ( p i c ) \log (p_{i c}) log(pic),现在要做的也就是找到x_input中各行所对应的独热的值。例如,这里y_target为:

tensor([[1],
        [2],
        [0]])

意思就是对于x的第一行,取第1个值;对于x的第二行,取第2个值;对于x的第三行,取第0个值。因为x为:

tensor([[-1.0207, -0.6645, -2.0784],
        [-1.0184, -1.8474, -0.7315],
        [-1.1617, -0.6996, -1.6595]])

因此取到的结果为:

tensor([[-0.6645],
        [-0.7315],
        [-1.1617]])

这一过程恰好是可以通过torch中的gather方法解决的。

最后求均值得到最终结果:

res = -1 * res
print(res.mean())

可以发现结果也为0.8526,完整代码如下:

import torch
import torch.nn.functional as F

torch.manual_seed(1234)
x_input = torch.randn(3,3)
y_target = torch.tensor([1,2,0])
y_target = y_target.view(-1, 1)
x_input = F.log_softmax(x_input, 1)
res = x_input.gather(1, y_target)
res = -1 * res
print(res.mean())

参考:
https://zhuanlan.zhihu.com/p/98785902
https://www.jianshu.com/p/0c159cdd9c50

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值