PyTorch中Cross Entropy Loss的用法与背景

重点提示

注意,PyTorch的Cross Entropy Loss与其它框架的不同,因为PyTorch中该损失函数其实自带了“nn.LogSoftmax”与“nn.NLLLoss”两个方法。因此,在PyTorch的Cross Entropy Loss之前请勿再使用Softmax方法!

使用场景

当现在面临多分类问题(不限于二分类问题)需要Loss函数时,Cross Entropy Loss是一个很方便的工具。

公式

loss ( x , c l a s s ) = − log ⁡ ( exp ⁡ ( x [ c l a s s ] ) ∑ j exp ⁡ ( x [ j ] ) ) = − x [ c l a s s ] + log ⁡ ( ∑ j exp ⁡ ( x [ j ] ) ) \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) = -x[class] + \log\left(\sum_j \exp(x[j])\right) loss(x,class)=log(jexp(x[j])exp(x[class]))=x[class]+log(jexp(x[j]))
其中class为样本x的类别,x[class]为样本类别对应的预测分数;对于中间的公式,核心思想就是计算样本目标类别分值的softmax后取负对数作为分类损失。

数学背景

设 s = exp ⁡ ( x [ c l a s s ] ) ∑ j exp ⁡ ( x [ j ] ) 设s = \frac{\exp(x[class])}{\sum_j \exp(x[j])} s=jexp(x[j])exp(x[class])
易知s其实是分类正确的概率,介于0~1之间;对照着log函数的图像可以看出,当s=1时(此时分类完全正确),取完对数为0;当s接近0时(分类完全错误),取完对数为负无穷,取复数后变为正无穷;恰好可以充当loss函数:分类正确时损失为0,分类越错误损失越大。

除此之外,可以看出log函数越接近0,梯度越大;log函数越接近1,梯度越小。因此在更新参数时,当网络分类错的很离谱(loss较大时),求导后会得到比较大的梯度,从而大幅更新网络参数;随着网络正确率的升高,梯度也会逐渐平缓,渐渐进入“微调阶段”。
在这里插入图片描述

用法

loss = torch.nn.CrossEntropyLoss()
output = loss(input, target)
output.backward()

其中,input为样本的预测结果矩阵,形状为(样本数量,类别数量),例如100个样本实现二分类形状就是(100,2),每一列的index分别表示对应的类别;target为标签向量,形状为(样本数量),其中为各样本对应的类别index。
假如是二分类,对应的最理想(完全正确)的预测结果应如下所示:

target:1,0,1

input:

0(类)1(类)
01
10
01

target向量的长度为3,说明现在有三个样本,其中第一个、第三个样本的标签均为1,第二个样本的标签为0;而input矩阵的形状为(3,2),每行为对应样本的预测结果分值,而每列为对应类别的分值。当我们希望获得当前样本被分类为1的分值时,我们取第1列,对应的向量为(1,0,1);当我们希望获得当前样本被分类为0的分值时,我们取第0列,对应的向量为(0,1,0)

实际预测中,很少能达到这么完美的情况,加上CrossEntropyLoss一般与Softmax连用,因此input矩阵中的每个元素表示的其实是第i个样本(i行)被分类为j类(j列)的概率

我们以二分类为例,如下所示:
input:

0(类)1(类)
0.30.7
0.60.4
0.20.8

该input矩阵表示的其实是第1个样本被分类到1类的概率是0.7,第2个样本被分类到1类的概率是0.4,第3个样本被分类到1类的概率是0.8。

  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值