手搓GPT系列之 - Logistic Regression模型,Softmax模型的损失函数与CrossEntropyLoss的关系

文章探讨了LogisticRegression和Softmax模型的目标函数与PyTorch中的CrossEntropyLoss之间的关系。CrossEntropyLoss实际上是这两个模型损失函数的通用化表达,适用于多分类情况。文中通过公式推导说明了LR模型和Softmax模型的损失函数如何转化为交叉熵形式,并指出交叉熵在处理标签不唯一情况的能力。
摘要由CSDN通过智能技术生成

笔者在学习各种分类模型和损失函数的时候发现了一个问题,类似于Logistic Regression模型和Softmax模型,目标函数都是根据最大似然公式推出来的,但是在使用pytorch进行编码的时候,却发现根本就没有提供softmax之类的损失函数,而提供了CrossEntropyLoss,MSELoss之类的。本文将介绍我们在学习LR模型和Softmax模型的时候接触到的目标函数,与实际应用中的经常用到的CrossEntropyLoss函数之间的关系。

弄懂了这个关系之后,笔者突然发现以前的一篇介绍LR模型和softmax模型基础的文章里存在一个十分傻的bug。本着线上有bug偷偷改,文章有bug坚决不改,不但不改还要四处宣扬的游街示众要不然怎么记得住的原则,笔者打算让那个bug保留在文章里,请各位朋友到评论区帮笔者找找这个bug吧。出bug的文章在这里:浅谈线性回归与softmax分类器

1. 交叉熵函数(Cross Entropy)

对于一个训练样本集,我们可以把损失函数理解为一个关于训练数据的模型输出 a a a,与该样本的标签 a ˙ \dot{a} a˙的函数,标记为 L ( a , a ˙ ) L(a,\dot{a}) L(a,a˙),该函数用于计算所有训练样本的 a a a值和 a ˙ \dot{a} a˙值之间的关系,当 a a a值和 a ˙ \dot{a} a˙值越接近, L ( a , a ˙ ) L(a,\dot{a}) L(a,a˙)越小,反之 L ( a , a ˙ ) L(a,\dot{a}) L(a,a˙)值越大。很多情况下,交叉熵公式(Cross Entropy)是一个很好的选择。这里写出交叉熵公式:
C r o s s E n t r o p y ( a , a ˙ ) = − ∑ a ˙ ⋅ l o g ( a ) CrossEntropy(a,\dot{a})=- \sum\dot{a} \cdot log(a) CrossEntropy(a,a˙)=a˙log(a)
交叉熵函数的图像为:
在这里插入图片描述
可以看到,当预测结果与实际结果越相符时,交叉熵越低;否则交叉熵会快速飙高以达到一个较大的惩罚。有人可能会有疑问:这如何解释LR模型和softmax模型的损失函数呢?

2. LR模型损失函数与CrossEntropy的关系

我们把LR模型的损失函数贴一下:
J ( x ; w , b ) = − 1 n ∑ i = 1 n ( q ( x i ) log ⁡ p ( x i ) + ( 1 − q ( x i ) ) log ⁡ ( 1 − p ( x i ) ) ) J(x;w,b) = -\frac1n\sum_{i=1}^n (q(x_i) \log p(x_i)+(1-q(x_i)) \log (1-p(x_i))) J(x;w,b)=n1i=1n(q(xi)logp(xi)+(1q(xi))log(1p(xi)))
提取出核心的部分:
− ( q ( x i ) log ⁡ p ( x i ) + ( 1 − q ( x i ) ) log ⁡ ( 1 − p ( x i ) ) ) (1) -(q(x_i) \log p(x_i)+(1-q(x_i)) \log (1-p(x_i)) \tag{1}) (q(xi)logp(xi)+(1q(xi))log(1p(xi)))(1)

设:该LR模型的标签集为 { T r u e , F a l s e } \{True,False\} {True,False},我们用 q ( T r u e ∣ x ) q(True|x) q(Truex) q ( F a l s e ∣ x ) q(False|x) q(Falsex)表示样本数据 x x x的实际标签数据。当 x x x的标签取 T r u e True True时, q ( T r u e ∣ x ) = 1 , q ( F a l s e ∣ x ) = 0 q(True|x)=1,q(False|x)=0 q(Truex)=1,q(Falsex)=0;当 x x x的标签取 F a l s e False False时, q ( T r u e ∣ x ) = 0 , q ( F a l s e ∣ x ) = 1 q(True|x)=0,q(False|x)=1 q(Truex)=0,q(Falsex)=1。式子 ( 1 ) (1) (1)可以改写为:
− ( q ( T r u e ∣ x i ) log ⁡ p ( T r u e ∣ x i ) + q ( F a l s e ∣ x i ) log ⁡ p ( F a l s e ∣ x i ) ) = − ∑ y = T r u e F a l s e q ( y ∣ x ) log ⁡ ( p ( y ∣ x ) ) -(q(True|x_i) \log p(True|x_i)+q(False|x_i) \log p(False|x_i)) = - \sum_{y=True}^{False}q(y|x)\log(p(y|x)) (q(Truexi)logp(Truexi)+q(Falsexi)logp(Falsexi))=y=TrueFalseq(yx)log(p(yx))
这个式子是交叉熵公式在二分类场景下的形式。因此这个LR模型的损失公式,其实是关于预测值与标签值之间的交叉熵公式。

3. softmax模型的损失函数与CrossEntropy的关系

同样贴下softmax的损失函数:
J ( x ; w , b ) = − 1 n ∑ i = 1 n log ⁡ exp ⁡ ( w y T x i ) ∑ c exp ⁡ ( w c T x i ) J(x;w,b) = -\frac1n \sum_{i=1}^n \log \frac{\exp(w_y^Tx_i)}{\sum_c \exp(w_c^Tx_i)} J(x;w,b)=n1i=1nlogcexp(wcTxi)exp(wyTxi)

上边这个函数是建立在一个前提上,即:测试数据集中所有数据的分类标签都是确定到一个具体分类。假设我们的标签集为 C = { c 1 , c 2 , . . . , c k } C=\{c_1,c_2,...,c_k\} C={c1,c2,...,ck},一共有k个分类,那么针对测试集中的样本数据 x x x,其标签数据 y y y为一个k维独热向量。也就是说,不允许有标签表示某个测试数据 x x x有一半可能属于 c 1 c_1 c1,另一半可能属于 c 2 c_2 c2
我们把这个公式的关键部分提取一下:
− ∑ log ⁡ exp ⁡ ( w y T x i ) ∑ c exp ⁡ ( w c T x i ) (2) -\sum \log \frac{\exp(w_y^Tx_i)}{\sum_c \exp(w_c^Tx_i)} \tag{2} logcexp(wcTxi)exp(wyTxi)(2)
由于:
exp ⁡ ( w y T x i ) ∑ c exp ⁡ ( w c T x i ) = p ( y ∣ x i ) \frac{\exp(w_y^Tx_i)}{\sum_c \exp(w_c^Tx_i)} = p(y|x_i) cexp(wcTxi)exp(wyTxi)=p(yxi)
p ( y ∣ x i ) p(y|x_i) p(yxi)替换可得:
− ∑ log ⁡ p ( y ∣ x i ) (3) -\sum \log p(y|x_i) \tag{3} logp(yxi)(3)
已知 y ∈ C y\in C yC,设 y = c k y=c_k y=ck,则式 ( 3 ) (3) (3)可以扩写为
− ∑ ( 0 ⋅ log ⁡ p ( c 1 ∣ x i ) + 0 ⋅ log ⁡ p ( c 2 ∣ x i ) + ⋯ + 0 ⋅ log ⁡ p ( c k − 1 ∣ x i ) + 1 ⋅ log ⁡ p ( y ∣ x i ) ) -\sum (0 \cdot \log p(c_1|x_i) + 0 \cdot \log p(c_2|x_i) + \cdots + 0 \cdot \log p(c_{k-1}|x_i ) + 1 \cdot \log p(y|x_i )) (0logp(c1xi)+0logp(c2xi)++0logp(ck1xi)+1logp(yxi))
上式可以写成交叉熵公式的形式:
− ∑ j = 1 k q ( y ∣ x i ) ⋅ log ⁡ p ( y ∣ x i ) -\sum_{j=1}^{k} q(y|x_i) \cdot \log p(y|x_i) j=1kq(yxi)logp(yxi)

4. 结论

CrossEntropy函数就是我们在学习LR模型和Softmax模型的时候经常遇到的目标函数的更加通用化的表示。不仅适用于多分类场景,也使用于训练数据的标签不唯一的情况,也就是某个训练数据 x x x的标签有50%的可能性为 c 1 c_1 c1,也有50%的可能性为 c 2 c_2 c2的情况。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值