Softmax 与交叉熵损失函数的反向传播公式推导
一、正向传播
用一个 X 举例,假设 S 为最后一层全连接层的输出,S 是一个长度为 c 的行向量,其中元素的含义为 c 个类分别的得分,即
s
1
s_1
s1 为 X 在第一个类的得分,以此类推。
S
=
{
s
1
,
s
2
,
s
3
,
…
,
s
c
}
(1)
\huge S=\{s_1,\ s_2,\ s_3,\ \dots,\ s_c\}\tag{1}
S={s1, s2, s3, …, sc}(1)
然后将 S 输入到 Softmax,输出一个长度为 c 的行向量 P,其中元素为各类别的概率。
P
=
{
p
1
,
p
2
,
p
3
,
…
,
p
c
}
(2)
\huge P=\{p_1,\ p_2,\ p_3,\ \dots,\ p_c\}\tag{2}
P={p1, p2, p3, …, pc}(2)
Softmax 每个元素的计算公式:
p
i
=
e
s
i
∑
j
=
1
c
e
s
j
=
e
s
i
e
s
1
+
e
s
2
+
e
s
3
+
⋯
+
e
s
c
(3)
\huge p_{i}=\frac{e^{s_{i}}}{\sum\limits_{j=1}^{c}e^{s_{j}}}=\frac{e^{s_{i}}}{e^{s_{1}}+e^{s_{2}}+e^{s_{3}}+\dots+e^{s_{c}}}\tag{3}
pi=j=1∑cesjesi=es1+es2+es3+⋯+escesi(3)
Loss function 为交叉熵损失函数,输入为 P 和真实标签 Y,Y也是一个长度为 c 的行向量。多数情况下对于分类问题,Y 会是一个 one-hot 向量,即 Y 中只有一个元素为 1,其余元素都为 0,其中 1 的下标表示 X 的类别。此文过程中先考虑复杂的情况,即 Y 中可能有多个小于 1 的元素,或者干脆假设 Y 中每个元素都是 1/c。本文末尾再给出 Y 为 one-hot 向量的简单情况。
下面为交叉熵损失函数的公式:
L
o
s
s
=
−
∑
j
=
1
c
y
j
ln
p
j
=
−
y
1
ln
p
1
−
y
2
ln
p
2
−
y
3
ln
p
3
−
⋯
−
y
c
ln
p
c
(
4
)
\huge\ Loss=-\sum\limits_{j=1}^{c}y_j\ln p_j\\ \huge =-y_1\ln p_1-y_2\ln p_2-y_3\ln p_3-\dots-y_c\ln p_c\ \ (4)
Loss=−j=1∑cyjlnpj=−y1lnp1−y2lnp2−y3lnp3−⋯−yclnpc (4)
二、反向传播
我们要求的误差项如下:
∂
L
∂
S
=
{
∂
L
∂
s
1
,
∂
L
∂
s
2
,
∂
L
∂
s
3
,
…
,
∂
L
∂
s
c
}
(5)
\huge \frac{\partial L}{\partial S}=\{\frac{\partial L}{\partial s_1},\ \frac{\partial L}{\partial s_2},\ \frac{\partial L}{\partial s_3},\ \dots,\ \frac{\partial L}{\partial s_c}\}\tag{5}
∂S∂L={∂s1∂L, ∂s2∂L, ∂s3∂L, …, ∂sc∂L}(5)
我们计算其中任意一个元素,比如第三个 ∂ L ∂ s 3 \huge \frac{\partial L}{\partial s_3} ∂s3∂L:
这里需要注意,由公式(4)可以看到 L 会受到
p
1
p_1
p1 到
p
c
p_c
pc 的影响,而看公式(3),每一个
p
p
p 的分母都包含了
s
3
s_3
s3,所以下面这个公式需要包含从
s
3
s_3
s3 到每一个
p
p
p 再到 L 的路径。
∂
L
∂
s
3
=
∑
j
=
1
c
∂
L
∂
p
j
∂
p
j
∂
s
3
=
∂
L
∂
p
1
∂
p
1
∂
s
3
+
∂
L
∂
p
2
∂
p
2
∂
s
3
+
∂
L
∂
p
3
∂
p
3
∂
s
3
+
…
∂
L
∂
p
c
∂
p
c
∂
s
3
(6)
\huge \frac{\partial L}{\partial s_3}=\sum\limits_{j=1}^{c}\frac{\partial L}{\partial p_j}\frac{\partial p_j}{\partial s_3}\\ \huge =\frac{\partial L}{\partial p_1}\frac{\partial p_1}{\partial s_3}+\frac{\partial L}{\partial p_2}\frac{\partial p_2}{\partial s_3}+\frac{\partial L}{\partial p_3}\frac{\partial p_3}{\partial s_3}+\dots\frac{\partial L}{\partial p_c}\frac{\partial p_c}{\partial s_3}\tag{6}
∂s3∂L=j=1∑c∂pj∂L∂s3∂pj=∂p1∂L∂s3∂p1+∂p2∂L∂s3∂p2+∂p3∂L∂s3∂p3+…∂pc∂L∂s3∂pc(6)
这里面的项有两种情况(因为
p
3
p_3
p3 对
s
3
s_3
s3 求导与其他的
p
p
p 对
s
3
s_3
s3 求导不一样),第一种情况为不含
p
3
p_3
p3 的项,例如第一项:
∂
L
∂
p
1
=
−
y
1
1
p
1
(7)
\huge \frac{\partial L}{\partial p_1}=-y_1\frac {1}{p_1}\tag{7}
∂p1∂L=−y1p11(7)
∂ p 1 ∂ s 3 = ∂ ( e s 1 e s 1 + e s 2 + e s 3 + ⋯ + e s c ) ∂ s 3 = 0 ∗ ( e s 1 + e s 2 + e s 3 + ⋯ + e s c ) − e s 1 ∗ e s 3 ( e s 1 + e s 2 + e s 3 + ⋯ + e s c ) 2 = − e s 1 ∗ e s 3 ( e s 1 + e s 2 + e s 3 + ⋯ + e s c ) 2 = − p 1 ∗ p 3 (8) \huge \frac{\partial p_1}{\partial s_3}=\frac{\partial(\frac{e^{s_{1}}}{e^{s_{1}}+e^{s_{2}}+e^{s_{3}}+\dots+e^{s_{c}}})}{\partial s_3}\\ \huge =\frac{0*(e^{s_{1}}+e^{s_{2}}+e^{s_{3}}+\dots+e^{s_{c}})-e^{s_{1}}*e^{s_3}}{(e^{s_{1}}+e^{s_{2}}+e^{s_{3}}+\dots+e^{s_{c}})^2}\\ \huge =\frac{-e^{s_{1}}*e^{s_3}}{(e^{s_{1}}+e^{s_{2}}+e^{s_{3}}+\dots+e^{s_{c}})^2}\\ \huge =-p_1*p_3\tag{8} ∂s3∂p1=∂s3∂(es1+es2+es3+⋯+esces1)=(es1+es2+es3+⋯+esc)20∗(es1+es2+es3+⋯+esc)−es1∗es3=(es1+es2+es3+⋯+esc)2−es1∗es3=−p1∗p3(8)
把(7)和(8)乘起来得到(6)中的第一项:
∂
L
∂
p
1
∂
p
1
∂
s
3
=
(
−
y
1
1
p
1
)
∗
(
−
p
1
∗
p
3
)
=
y
1
p
3
(9)
\huge \frac{\partial L}{\partial p_1}\frac{\partial p_1}{\partial s_3}=(-y_1\frac {1}{p_1})*(-p_1*p_3)\\ \huge =y_1p_3\tag{9}
∂p1∂L∂s3∂p1=(−y1p11)∗(−p1∗p3)=y1p3(9)
第二种情况为(6)中的第三项:
∂
L
∂
p
3
=
−
y
3
1
p
3
(10)
\huge \frac{\partial L}{\partial p_3}=-y_3\frac {1}{p_3}\tag{10}
∂p3∂L=−y3p31(10)
∂
p
3
∂
s
3
=
∂
(
e
s
3
e
s
1
+
e
s
2
+
e
s
3
+
⋯
+
e
s
c
)
∂
s
3
=
e
3
s
∗
(
e
s
1
+
e
s
2
+
e
s
3
+
⋯
+
e
s
c
)
−
e
s
3
∗
e
s
3
(
e
s
1
+
e
s
2
+
e
s
3
+
⋯
+
e
s
c
)
2
=
p
3
−
(
p
3
)
2
(11)
\huge \frac{\partial p_3}{\partial s_3}=\frac{\partial(\frac{e^{s_{3}}}{e^{s_{1}}+e^{s_{2}}+e^{s_{3}}+\dots+e^{s_{c}}})}{\partial s_3}\\ \huge =\frac{e^s_{3}*(e^{s_{1}}+e^{s_{2}}+e^{s_{3}}+\dots+e^{s_{c}})-e^{s_{3}}*e^{s_3}}{(e^{s_{1}}+e^{s_{2}}+e^{s_{3}}+\dots+e^{s_{c}})^2}\\ \huge =p_3-(p_3)^2\tag{11}
∂s3∂p3=∂s3∂(es1+es2+es3+⋯+esces3)=(es1+es2+es3+⋯+esc)2e3s∗(es1+es2+es3+⋯+esc)−es3∗es3=p3−(p3)2(11)
把(10)和(11)乘起来得到(6)中的第三项:
∂ L ∂ p 3 ∂ p 3 ∂ s 3 = ( − y 3 1 p 3 ) ∗ ( p 3 − ( p 3 ) 2 ) = y 3 p 3 − y 3 \huge \frac{\partial L}{\partial p_3}\frac{\partial p_3}{\partial s_3} \huge =(-y_3\frac {1}{p_3})*(p_3-(p_3)^2)\\ \huge = y_3p_3-y_3 ∂p3∂L∂s3∂p3=(−y3p31)∗(p3−(p3)2)=y3p3−y3
故公式(6)也就等于:
∂
L
∂
s
3
=
y
1
p
3
+
y
2
p
3
+
(
y
3
p
3
−
y
3
)
+
⋯
+
y
c
p
3
=
p
3
∑
j
=
1
c
y
j
−
y
3
(12)
\huge \frac{\partial L}{\partial s_3}=y_1p_3+y_2p_3+(y_3p_3-y_3)+\dots+y_cp_3\\ \huge =p_3\sum\limits_{j=1}^{c}y_j-y_3\tag{12}
∂s3∂L=y1p3+y2p3+(y3p3−y3)+⋯+ycp3=p3j=1∑cyj−y3(12)
到这里我们也就得到了公式(5)中的第一项。
整理一下得到总的误差项:
∂
L
∂
S
=
{
∂
L
∂
s
1
,
∂
L
∂
s
2
,
∂
L
∂
s
3
,
…
,
∂
L
∂
s
c
}
=
{
p
1
∑
j
=
1
c
y
j
−
y
1
,
p
2
∑
j
=
1
c
y
j
−
y
2
,
p
3
∑
j
=
1
c
y
j
−
y
3
,
…
,
p
c
∑
j
=
1
c
y
j
−
y
c
}
\huge \frac{\partial L}{\partial S}=\{\frac{\partial L}{\partial s_1},\ \frac{\partial L}{\partial s_2},\ \frac{\partial L}{\partial s_3},\ \dots,\ \frac{\partial L}{\partial s_c}\}\\ \large =\{p_1\sum\limits_{j=1}^{c}y_j-y_1,\ p_2\sum\limits_{j=1}^{c}y_j-y_2,\ p_3\sum\limits_{j=1}^{c}y_j-y_3,\ \dots,\ p_c\sum\limits_{j=1}^{c}y_j-y_c\}
∂S∂L={∂s1∂L, ∂s2∂L, ∂s3∂L, …, ∂sc∂L}={p1j=1∑cyj−y1, p2j=1∑cyj−y2, p3j=1∑cyj−y3, …, pcj=1∑cyj−yc}
最后来考虑常见的简单情况,即 Y 为 one-hot 向量时,假设
y
i
y_i
yi 为1,其余元素都为 0,此时的
∂
L
∂
S
\huge \frac{\partial L}{\partial S}
∂S∂L 为:
∂
L
∂
S
=
{
p
1
y
i
,
p
2
y
i
,
p
3
y
i
,
…
,
p
i
y
i
−
y
i
,
…
,
p
c
y
i
}
=
{
p
1
,
p
2
,
p
3
,
…
,
p
i
−
1
,
…
,
p
c
}
=
P
−
Y
(14)
\huge \frac{\partial L}{\partial S} \huge =\{p_1y_i,\ p_2y_i,\ p_3y_i,\ \dots,\ p_iy_i-y_i,\ \dots,\ p_cy_i\}\\ \huge =\{p_1,\ p_2,\ p_3,\ \dots,\ p_i-1,\ \dots,\ p_c\}\\ \huge =P-Y\tag{14}
∂S∂L={p1yi, p2yi, p3yi, …, piyi−yi, …, pcyi}={p1, p2, p3, …, pi−1, …, pc}=P−Y(14)