分类问题为什么用交叉熵损失不用 MSE 损失

  • 本文说明以下问题

    1. MSE 损失主要适用与回归问题,因为优化 MSE 等价于对高斯分布模型做极大似然估计,而简单回归中做服从高斯分布的假设是比较合理的
    2. 交叉熵损失主要适用于多分类问题,因为优化交叉熵损失等价于对多项式分布模型做极大似然估计,而多分类问题通常服从多项式分布

    事实上,最大似然估计往往将损失建模为负对数似然,这样的损失一定等价于定义在训练集上的经验分布和定义在模型上的概率分布间的交叉熵,这个交叉熵根据模型定义有时可以转化为不同的损失,这块可以参考:信息论概念详细梳理:信息量、信息熵、条件熵、互信息、交叉熵、KL散度、JS散度 4.1 节

  • 先明确本文讨论的多分类问题的符号:

    1. 训练样本集大小为 N N N
    2. 类别数为 K K K
    3. i i i 个样本 x ( i ) \pmb{x}^{(i)} x(i) 的真实标记概率分布为 y ( i ) = { y 1 ( i ) , y 2 ( i ) , . . . , y K ( i ) } \pmb{y}^{(i)}=\{y_1^{(i)},y_2^{(i)},...,y_K^{(i)}\} y(i)={y1(i),y2(i),...,yK(i)},事实上这是一个 one-hot 向量
    4. i i i 个样本 x ( i ) \pmb{x}^{(i)} x(i) 的预测标记概率分布为 y ^ ( i ) = { y ^ 1 ( i ) , y ^ 2 ( i ) , . . . , y ^ K ( i ) } \pmb{\hat{y}}^{(i)}=\{\hat{y}_1^{(i)},\hat{y}_2^{(i)},...,\hat{y}_K^{(i)}\} y^(i)={y^1(i),y^2(i),...,y^K(i)}

    这种情况下,MSE 损失和交叉熵损失分别为

    1. MSE 损失: L = 1 N ∑ i N ∣ ∣ y ( i ) − y ^ ( i ) ∣ ∣ 2 = 1 N ∑ i = 1 N ∑ k = 1 K ( y k ( i ) − y ^ k ( i ) ) 2 L = \frac{1}{N}\sum_{i}^N ||\pmb{y}^{(i)}-\pmb{\hat{y}}^{(i)}||^2 = \frac{1}{N}\sum_{i=1}^N\sum_{k=1}^K(y_k^{(i)}-\hat{y}_k^{(i)})^2 L=N1iN∣∣y(i)y^(i)2=N1i=1Nk=1K(yk(i)y^k(i))2
    2. 交叉熵损失: L = − 1 N ∑ i = 1 N ∑ k = 1 K y k ( i ) l o g y ^ k ( i ) L=-\frac{1}{N}\sum_{i=1}^N\sum_{k=1}^Ky_k^{(i)}log\hat{y}_k^{(i)} L=N1i=1Nk=1Kyk(i)logy^k(i)

1. 概率角度

1.1 优化 MSE 损失等价于高斯分布的最大似然估计

  • 我们可以把第 i i i 个样本 x ( i ) \pmb{x}^{(i)} x(i) 的真实标记值 y ( i ) \pmb{y}^{(i)} y(i) 看做预测标记值 y ^ ( i ) \hat{\pmb{y}}^{(i)} y^(i) 加上噪音误差 e ( i ) \pmb{e}^{(i)} e(i) 所得,假设误差 e ( i ) ∼ N ( 0 , B ) \pmb{e}^{(i)}\sim N(0,\pmb{B}) e(i)N(0,B) 服从期望为 0,协方差矩阵为 B \pmb{B} B K K K 维高斯分布,则样本的真实标签 y ( i ) = y ^ ( i ) + e ( i ) = f ( x ( i ) , w ) + e ( i ) ∼ N ( f ( x ( i ) , w ) , B ) \pmb{y}^{(i)} = \hat{\pmb{y}}^{(i)}+\pmb{e}^{(i)} = f(\pmb{x}^{(i)},\pmb{w})+\pmb{e}^{(i)} \sim N(f(\pmb{x}^{(i)},\pmb{w}),\pmb{B}) y(i)=y^(i)+e(i)=f(x(i),w)+e(i)N(f(x(i),w),B) 也服从期望为 y ^ ( i ) = f ( x ( i ) , w ) \pmb{\hat{y}}^{(i)} = f(\pmb{x}^{(i)},\pmb{w}) y^(i)=f(x(i),w),协方差矩阵为 B \pmb{B} B K K K 维高斯分布,有
    p ( y ( i ) ∣ x ( i ) , w ) = 1 ( 2 π ) n / 2 ∣ B ∣ 1 / 2 e − 1 2 △ ( i ) p(\pmb{y}^{(i)}|\pmb{x}^{(i)},\pmb{w}) = \frac{1}{(2\pi)^{n/2}|\pmb{B}|^{1/2}}e^{-\frac{1}{2}\triangle^{(i)}} p(y(i)x(i),w)=(2π)n/2B1/21e21(i) 其中 △ ( i ) = [ y ( i ) − y ^ ( i ) ] ⊤ B − 1 [ y ( i ) − y ^ ( i ) ] \triangle^{(i)} = [\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}]^\top\pmb{B}^{-1}[\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}] (i)=[y(i)y^(i)]B1[y(i)y^(i)]

  • 由于样本独立同分布,整个样本集的似然函数为
    L ( w ) = p ( y ( 1 ) ∣ x ( 1 ) , w ) p ( y ( 2 ) ∣ x ( 2 ) , w ) . . . . p ( y ( N ) ∣ x ( N ) , w ) = ∏ i = 1 N p ( y ( i ) ∣ x ( i ) , w ) L(\pmb{w}) =p(\pmb{y}^{(1)}|\pmb{x}^{(1)},\pmb{w})p(\pmb{y}^{(2)}|\pmb{x}^{(2)},\pmb{w})....p(\pmb{y}^{(N)}|\pmb{x}^{(N)},\pmb{w}) = \prod\limits_{i=1}^Np(\pmb{y}^{(i)}|\pmb{x}^{(i)},\pmb{w}) L(w)=p(y(1)x(1),w)p(y(2)x(2),w)....p(y(N)x(N),w)=i=1Np(y(i)x(i),w) 通过最大化对数似然函数的方式得到参数 w \pmb{w} w 的估计值,即
    w ^ = arg max ⁡ w ^ L ( w ) = arg max ⁡ w ^ ∏ i = 1 N p ( y ( i ) ∣ x ( i ) , w ) = arg max ⁡ w ^ ∑ i = 1 N log ⁡ p ( y ( i ) ∣ x ( i ) , w ) = arg max ⁡ w ^ ∑ i = 1 N ( log ⁡ e − 1 2 △ ( i ) ) = arg max ⁡ w ^ ∑ i = 1 N − △ ( i ) = arg min ⁡ w ^ ∑ i = 1 N △ ( i ) = arg min ⁡ w ^ ∑ i = 1 N [ y ( i ) − y ^ ( i ) ] ⊤ B − 1 [ y ( i ) − y ^ ( i ) ] \begin{aligned} \hat{\pmb{w}} &= \argmax\limits_{\mathbf{\hat{w}}} L(\pmb{w}) \\ &= \argmax\limits_{\mathbf{\hat{w}}}\prod_{i=1}^Np(y^{(i)}|\pmb{x}^{(i)},\pmb{w}) \\ &= \argmax\limits_{\mathbf{\hat{w}}}\sum_{i=1}^N\log p(y^{(i)}|\pmb{x}^{(i)},\pmb{w}) \\ &= \argmax\limits_{\mathbf{\hat{w}}}\sum_{i=1}^N (\log e^{-\frac{1}{2}\triangle^{(i)}})\\ &= \argmax\limits_{\mathbf{\hat{w}}} \sum_{i=1}^N-\triangle^{(i)} \\ &= \argmin\limits_{\mathbf{\hat{w}}}\sum_{i=1}^N \triangle^{(i)} \\ &= \argmin\limits_{\mathbf{\hat{w}}}\sum_{i=1}^N [\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}]^\top\pmb{B}^{-1}[\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}] \end{aligned} w^=w^argmaxL(w)=w^argmaxi=1Np(y(i)x(i),w)=w^argmaxi=1Nlogp(y(i)x(i),w)=w^argmaxi=1N(loge21(i))=w^argmaxi=1N(i)=w^argmini=1N(i)=w^argmini=1N[y(i)y^(i)]B1[y(i)y^(i)]

  • 这时考虑特殊情况

    1. K K K 维正态分布的各个维度相互独立且各项同性时,协方差矩阵变为单位矩阵,有
      △ ( i ) = ∣ ∣ y ( i ) − y ^ ( i ) ∣ ∣ 2 \triangle^{(i)} = ||\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}||^2 (i)=∣∣y(i)y^(i)2 这时最大似然估计的结果为
      w ^ = arg min ⁡ w ^ ∑ i = 1 N ∣ ∣ y ( i ) − y ^ ( i ) ∣ ∣ 2 \hat{\pmb{w}} = \argmin\limits_{\mathbf{\hat{w}}} \sum_{i=1}^N||\pmb{y}^{(i)}-\hat{\pmb{y}}^{(i)}||^2 w^=w^argmini=1N∣∣y(i)y^(i)2 这和最小化 MSE 损失的优化目标一致
    2. 进一步特殊化,当 K = 1 K=1 K=1 时, B = 1 \pmb{B}=1 B=1,退化到一元线性回归情况
  • 我们认为,误差是由于随机的、无数的、独立的、多个因素造成的,因此根据中心极限定理,预测误差在大样本量的情况下确实服从正态分布,因此优化 MSE 损失等价于对高斯分布做最大似然估计

  • 注意到,只有当样本标记服从多个维度相互独立且各项同性的多维高斯分布时,优化 MSE 损失才等价于做最大似然估计,而这往往很难达成,这也是 MSE 不适用于多分类问题而适用于一元线性回归的原因之一(后者自动满足这个条件)

1.2 优化交叉熵损失等价于多项式分布的最大似然

1.2.1 多项式分布

  • 我们借助从伯努利分布到二项分布的变形,从 categorical 分布导出多项式分布
  • 先看熟悉的伯努利分布:抛硬币正面朝上的概率为 θ \theta θ,这时抛 1 次硬币,出现正面次数 X = m ∈ { 0 , 1 } X=m\in\{0,1\} X=m{0,1} 的概率服从伯努利分布
    P ( X = m ∣ θ ) = θ m ( 1 − θ ) 1 − m ,   m ∈ { 0 , 1 } P(X=m|\theta) = \theta^m(1-\theta)^{1-m},\space m\in\{0,1\} P(X=mθ)=θm(1θ)1m, m{0,1} 再看二项分布:抛硬币正面朝上概率为 θ \theta θ ,抛 n n n 次硬币,出现正面次数 X = m X=m X=m 的概率服从二项分布
    P ( X = m ∣ θ , n ) = C n m θ m ( 1 − θ ) n − m P(X=m|\theta,n) = C_n^m\theta^m(1-\theta)^{n-m} P(X=mθ,n)=Cnmθm(1θ)nm
  • categorical 分布可以类比伯努利分布: K K K 面骰子,每一面出现概率分别为 θ 1 , θ 2 , . . . , θ K \theta_1,\theta_2,...,\theta_K θ1,θ2,...,θK,抛 1 次骰子,第 p p p 面出现次数 X = m p ∈ { 0 , 1 } X=m_p\in\{0,1\} X=mp{0,1} 的概率服从 categorical 分布(下式 ∑ k = 1 K θ i = 1 \sum_{k=1}^K\theta_i=1 k=1Kθi=1 m k ∈ { 0 , 1 } m_k\in\{0,1\} mk{0,1} ∑ k = 1 K m k = 1 \sum_{k=1}^Km_k=1 k=1Kmk=1
    P ( X = m p ∣ θ 1 , θ 2 , . . . , θ K ) = ∏ k = 1 K θ k m k P(X=m_p|\theta_1,\theta_2,...,\theta_K) = \prod_{k=1}^K \theta_k^{m_k} P(X=mpθ1,θ2,...,θK)=k=1Kθkmk 多项式分布可以类比二项分布: K K K 面骰子,每一面出现概率分别为 θ 1 , θ 2 , . . . , θ K \theta_1,\theta_2,...,\theta_K θ1,θ2,...,θK,抛 N N N 次骰子,第 1 面到第 k k k 面出现次数为 X 1 = m 1 , X 2 = m 2 , . . . , X K = m K X_1=m_1,X_2=m_2,...,X_K=m_K X1=m1,X2=m2,...,XK=mK 的概率服从多项分布 (下式 ∑ k = 1 K θ k = 1 \sum_{k=1}^K\theta_k=1 k=1Kθk=1 ∑ k = 1 K m k = n \sum_{k=1}^Km_k=n k=1Kmk=n
    P ( X 1 = m 1 , X 2 = m 2 , . . . , X K = m K ∣ θ 1 , θ 2 , . . . , θ K , N ) = C N m 1 θ 1 m 1 C N − m 1 m 2 θ 2 m 2 . . . C N − m 1 − m 2 − . . . − m k − 1 m K θ K m K = N ! m 1 ! m 2 ! . . . m K ! ∏ k = 1 K θ k m K \begin{aligned} P(X_1=m_1,X_2=m_2,...,X_K=m_K|\theta_1,\theta_2,...,\theta_K,N) &= C_N^{m_1}\theta_1^{m_1}C_{N-m_1}^{m_2}\theta_2^{m_2}...C_{N-m_1-m_2-...-m_{k-1}}^{m_K}\theta_K^{m_K}\\ &= \frac{N!}{m_1!m_2!...m_K!}\prod_{k=1}^K\theta_k^{m_K} \end{aligned} P(X1=m1,X2=m2,...,XK=mKθ1,θ2,...,θK,N)=CNm1θ1m1CNm1m2θ2m2...CNm1m2...mk1mKθKmK=m1!m2!...mK!N!k=1KθkmK
  • 对于多分类问题来说,可以把总类别数看做这里的 K K K,把各个类别的预测概率(模型输出概率)看做这里的 θ \theta θ,把总样本数看做这里的 N N N,把样本真实标记分布(one-hot 向量)看做这里的 m m m,则多分类问题的样本集可以看做服从多项式分布,上式可改写为
    N ! m 1 ! m 2 ! . . . m K ! ∏ k = 1 K ( y ^ k ( i ) ) ∑ i y k ( i ) = N ! m 1 ! m 2 ! . . . m K ! ∏ i = 1 N ∏ k = 1 K ( y ^ k ( i ) ) y k ( i ) \begin{aligned} \frac{N!}{m_1!m_2!...m_K!}\prod_{k=1}^K(\hat{y}_k^{(i)})^{\sum_i y_k^{(i)}} = \frac{N!}{m_1!m_2!...m_K!} \prod_{i=1}^N\prod_{k=1}^K(\hat{y}_k^{(i)})^{y_k^{(i)}} \end{aligned} m1!m2!...mK!N!k=1K(y^k(i))iyk(i)=m1!m2!...mK!N!i=1Nk=1K(y^k(i))yk(i) 前面的 N ! m 1 ! m 2 ! . . . m K ! \frac{N!}{m_1!m_2!...m_K!} m1!m2!...mK!N! 是归一化系数,目的是使总和等于 1 以满足概率形式,它是个常数,并不重要
  • 如果对分布的形式还不是很清楚,可以看这个例子

    在这里插入图片描述

1.2.2 优化交叉熵损失等价于多项式分布的最大似然

  • 根据 categorical 分布,第 i i i 个样本 x i \pmb{x}_i xi 真实标记 y ( i ) \pmb{y}^{(i)} y(i) 出现概率为
    p ( y ( i ) ∣ x ( i ) , w ) = ∏ k = 1 K ( y ^ k ( i ) ) y k ( i ) p(\pmb{y}^{(i)}|\pmb{x}^{(i)},\pmb{w}) = \prod_{k=1}^K (\hat{y}_k^{(i)})^{y_k^{(i)}} p(y(i)x(i),w)=k=1K(y^k(i))yk(i) 注意这里是 y ( i ) \pmb{y}^{(i)} y(i) 是 one-hot 向量, y k y_k yk 中只有一个值为 1,其他都是 0
  • 似然函数为
    L ( w ) = ∏ i = 1 N p ( y ( i ) ∣ x ( i ) , w ) = ∏ i = 1 N ∏ k = 1 K ( y ^ k ( i ) ) y k ( i ) L(\pmb{w}) = \prod_{i=1}^Np(\pmb{y}^{(i)}|\pmb{x}^{(i)},\pmb{w}) = \prod_{i=1}^N\prod_{k=1}^K (\hat{y}_k^{(i)})^{y_k^{(i)}} L(w)=i=1Np(y(i)x(i),w)=i=1Nk=1K(y^k(i))yk(i) 通过最大化对数似然函数的方式得到参数 w \pmb{w} w 的估计值,即
    w ^ = arg max ⁡ w ^ ∏ i = 1 N ∏ k = 1 K ( y ^ k ( i ) ) y k ( i ) = arg max ⁡ w ^ ∑ i = 1 N ∑ k = 1 K y k ( i ) log ⁡ y ^ k ( i ) = arg min ⁡ w ^ − ∑ i = 1 N ∑ k = 1 K y k ( i ) log ⁡ y ^ k ( i ) \begin{aligned} \hat{\pmb{w}} &= \argmax\limits_{\mathbf{\hat{w}}}\prod_{i=1}^N\prod_{k=1}^K (\hat{y}_k^{(i)})^{y_k^{(i)}}\\ &= \argmax\limits_{\mathbf{\hat{w}}}\sum_{i=1}^N\sum_{k=1}^K {y_k^{(i)}}\log \hat{y}_k^{(i)}\\ &= \argmin\limits_{\mathbf{\hat{w}}} - \sum_{i=1}^N\sum_{k=1}^K {y_k^{(i)}}\log \hat{y}_k^{(i)}\\ \end{aligned} w^=w^argmaxi=1Nk=1K(y^k(i))yk(i)=w^argmaxi=1Nk=1Kyk(i)logy^k(i)=w^argmini=1Nk=1Kyk(i)logy^k(i) 这和最小化交叉熵损失的优化目标一致

2. 梯度角度

  • 从梯度角度看,如果使用了 sigmoid 或类似形状的激活函数,对于本文开头提出的多分类问题,在计算梯度时
    1. MSE 损失,参数梯度关于绝对误差是一个凹函数形式,导致更新强度和绝对误差值不成正比,优化过程低效
    2. 交叉熵损失,参数梯度关于绝对误差是线性函数形式,更新强度和绝对误差值成正比,优化过程高效稳定
  • 具体推导请参考:为什么使用交叉熵作为损失函数

3. 直观角度

  • 回过头看两个损失函数
    1. MSE 损失: L = 1 N ∑ i N ∣ ∣ y ( i ) − y ^ ( i ) ∣ ∣ 2 = 1 N ∑ i = 1 N ∑ k = 1 K ( y k ( i ) − y ^ k ( i ) ) 2 L = \frac{1}{N}\sum_{i}^N ||\pmb{y}^{(i)}-\pmb{\hat{y}}^{(i)}||^2 = \frac{1}{N}\sum_{i=1}^N\sum_{k=1}^K(y_k^{(i)}-\hat{y}_k^{(i)})^2 L=N1iN∣∣y(i)y^(i)2=N1i=1Nk=1K(yk(i)y^k(i))2
    2. 交叉熵损失: L = − 1 N ∑ i = 1 N ∑ k = 1 K y k ( i ) l o g y ^ k ( i ) L=-\frac{1}{N}\sum_{i=1}^N\sum_{k=1}^Ky_k^{(i)}log\hat{y}_k^{(i)} L=N1i=1Nk=1Kyk(i)logy^k(i) 由于 y ( i ) \pmb{y}^{(i)} y(i) 是 one-hot 向量,假设 k i k_i ki 是第 i i i 个样本标记类别,可以进一步化简为 L = − 1 N ∑ i = 1 N l o g y ^ k i ( i ) L=-\frac{1}{N}\sum_{i=1}^Nlog\hat{y}_{k_i}^{(i)} L=N1i=1Nlogy^ki(i)
  • 可见:MSE无差别地关注全部类别上预测概率和真实概率的差交叉熵关注的是正确类别的预测概率
    1. 如果真实标签是 ( 1 , 0 , 0 ) (1, 0, 0) (1,0,0),模型1的预测标签是 ( 0.8 , 0.2 , 0 ) (0.8, 0.2, 0) (0.8,0.2,0),模型2的是 ( 0.8 , 0.1 , 0.1 ) (0.8, 0.1, 0.1) (0.8,0.1,0.1),那么MSE-based 认为模型2更好;交叉熵-based认为一样。从最终预测的类别上看,模型1和模型2的真实输出其实是一样的
    2. 再换个角度,MSE对残差大的样例惩罚更大些。比如真实标签分别是 ( 1 , 0 , 0 ) (1, 0, 0) (1,0,0),模型1的预测标签是 ( 0.8 , 0.2 , 0 ) (0.8, 0.2, 0) (0.8,0.2,0),模型2的是 ( 0.9 , 0.1 , 0 ) (0.9, 0.1, 0) (0.9,0.1,0),即使输出的标签都是类别0, 但 MSE-based 算出来模型1的误差是模型2的4倍;而交叉熵-based算出来模型1的误差是模型2的2倍左右。为了弥补模型1在这个样例上的损失,MSE-based需要3个完美预测的样例才能达到和模型2一样的损失,而交叉熵-based只需要一个。实际上,模型输出正确的类别,0.8可能已经是个不错的概率了.
  • 本段参考:MSE vs 交叉熵
  • 16
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云端FFF

所有博文免费阅读,求打赏鼓励~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值