本文主要是讲述Softmax和CrossEntropy的公式推导,并用代码进一步佐证。
1. Softmax前向计算
我们把
S
o
f
t
m
a
x
Softmax
Softmax输出的概率定义为
p
i
p_i
pi:
S
o
f
t
m
a
x
(
a
i
)
=
p
i
=
e
a
i
∑
j
N
e
a
j
Softmax(a_i) = p_i = \frac {e^{a_i}} {\sum_j^N e^{a_j}}
Softmax(ai)=pi=∑jNeajeai
模型输出
[
a
1
,
a
2
,
.
.
.
,
a
N
]
[a_1, a_2, ..., a_N]
[a1,a2,...,aN],共N个值。
其中
a
i
a_i
ai代表第
i
i
i个输出值,
p
i
p_i
pi代表第
i
i
i个输出值经过
S
o
f
t
m
a
x
Softmax
Softmax计算过后的概率。
且
p
1
+
p
2
+
.
.
.
+
p
N
=
1
p_1+p_2+...+p_N=1
p1+p2+...+pN=1
1.1 数值稳定
因为Softmax涉及到指数函数,且底数 e e e 大于1,在计算机中是可能会有溢出风险的。结合指数、对数函数的转换规则,我们可以制定一些数值稳定的优化策略。(当然这些是在框架中实现的,学习更多是为了扩展视野)
数值稳定的主要思路在于
a
i
a_i
ai减去
A
=
[
a
1
,
a
2
,
.
.
.
,
a
N
]
A=[a_1, a_2, ..., a_N]
A=[a1,a2,...,aN]中的最大值
m
a
x
(
A
)
max(A)
max(A)
p
i
=
e
a
i
∑
j
N
e
a
j
=
C
⋅
e
a
i
C
⋅
∑
j
N
e
a
j
=
e
log
(
C
)
⋅
e
a
i
e
log
(
C
)
⋅
∑
j
N
e
a
j
=
e
a
i
+
log
(
C
)
∑
j
N
e
a
j
+
log
(
C
)
=
e
a
i
−
m
a
x
(
A
)
∑
j
N
e
a
j
−
m
a
x
(
A
)
\begin{aligned} p_i & = \frac {e^{a_i}} {\sum_j^N e^{a_j}}\\ & = \frac {C \cdot e^{a_i}} {C \cdot \sum_j^N e^{a_j}}\\ & = \frac {e^{\log(C)} \cdot e^{a_i}} {e^{\log(C)} \cdot \sum_j^N e^{a_j}}\\ & = \frac {e^{a_i + \log(C)}} {\sum_j^N e^{a_j + \log(C)}}\\ & = \frac {e^{a_i - max(A)}} {\sum_j^N e^{a_j - max(A)}}\\ \end{aligned}
pi=∑jNeajeai=C⋅∑jNeajC⋅eai=elog(C)⋅∑jNeajelog(C)⋅eai=∑jNeaj+log(C)eai+log(C)=∑jNeaj−max(A)eai−max(A)
因为C是常数,
l
o
g
(
C
)
log(C)
log(C)也是常数,所以我们可以类比到分子分母同时加上 -max(A),并不会改变
p
i
p_i
pi的计算结果。
A中各项均减去最大值,就能确保A中所有项都不会上溢出。
2. Cross-Entropy前向计算
我们把交叉熵损失(Cross Entropy Loss)定义为
H
H
H,同时传入Softmax得出的概率
p
i
p_i
pi及其对应的Label
y
i
y_i
yi:
C
E
L
o
s
s
(
y
i
,
p
i
)
=
H
(
y
i
,
p
i
)
=
−
∑
i
N
y
i
⋅
log
(
p
i
)
CELoss(y_i, p_i) = H(y_i, p_i) = -\sum_i^Ny_i \cdot \log (p_i)
CELoss(yi,pi)=H(yi,pi)=−i∑Nyi⋅log(pi)
在多分类问题中,我们的 Label 通常以独热码(one-hot)的形式展现和训练,因此在
Y
=
[
y
1
,
y
2
,
.
.
.
,
y
N
]
Y=[y_1, y_2, ..., y_N]
Y=[y1,y2,...,yN] 中,只有一项为
1
1
1,其余项为
0
0
0,即
[
0
,
0
,
.
.
.
,
1
,
.
.
.
,
0
,
0
]
[0, 0, ..., 1, ..., 0, 0]
[0,0,...,1,...,0,0]。
所以
H
(
y
i
,
p
i
)
H(y_i, p_i)
H(yi,pi) 也等于
−
y
i
⋅
log
(
p
i
)
-y_i \cdot \log(p_i)
−yi⋅log(pi),
y
i
=
1
y_i=1
yi=1 对应Label的类别。
3. Softmax反向传播求导
因为Softmax+Cross-Entropy的反向传播包含基于Softmax的求导公式,所以我们先推导Softmax的导数。
据
S
o
f
t
m
a
x
Softmax
Softmax 公式可知,每个
p
i
p_i
pi 均是所有
a
a
a 都有参与运算的(在分母的累加中体现),因此梯度的形式为:
∂
p
i
∂
a
j
=
∂
(
e
a
i
∑
j
N
e
a
j
)
∂
a
j
\frac {\partial p_i} {\partial a_j} = \frac{\partial (\frac {e^{a_i}}{\sum_j^N e^{a_j}})}{\partial a_j}
∂aj∂pi=∂aj∂(∑jNeajeai)
因为
i
i
i和
j
j
j可能不相同,所以
i
i
i 和
j
j
j 的关系要分类讨论。
这里要先复习下含分母的求导公式:
(
h
(
x
)
g
(
x
)
)
′
=
h
′
(
x
)
⋅
g
(
x
)
−
h
(
x
)
⋅
g
′
(
x
)
g
(
x
)
2
(\frac{h(x)}{g(x)})^\prime = \frac{h'(x)\cdot g(x)-h(x)\cdot g'(x)}{g(x)^2}
(g(x)h(x))′=g(x)2h′(x)⋅g(x)−h(x)⋅g′(x)
并且简化一下符号:
∑
j
N
e
a
j
=
∑
\sum_j^Ne^{a_j} = \sum
j∑Neaj=∑
当
i
=
j
i=j
i=j:
∂
p
i
∂
a
j
=
e
a
i
⋅
∑
−
e
a
i
⋅
e
a
j
∑
⋅
∑
=
e
a
i
⋅
(
∑
−
e
a
j
)
∑
⋅
∑
=
p
i
⋅
(
1
−
p
j
)
\begin{aligned} \frac {\partial p_i} {\partial a_j} & = \frac{e^{a_i} \cdot \sum - e^{a_i} \cdot e^{a_j}}{\sum \cdot \sum} \\ & = \frac{e^{a_i} \cdot (\sum - e^{a_j})}{\sum \cdot \sum}\\ & = p_i \cdot (1 - p_j)\\ \end{aligned}
∂aj∂pi=∑⋅∑eai⋅∑−eai⋅eaj=∑⋅∑eai⋅(∑−eaj)=pi⋅(1−pj)
当
i
≠
j
i \neq j
i=j(对
a
j
a_j
aj求导,相当于
e
a
i
e^{a_i}
eai是常数,导数为 0):
∂
p
i
∂
a
j
=
0
⋅
∑
−
e
a
i
⋅
e
a
j
∑
⋅
∑
=
−
p
i
⋅
p
j
\begin{aligned} \frac {\partial p_i} {\partial a_j} & = \frac{0 \cdot \sum - e^{a_i} \cdot e^{a_j}}{\sum \cdot \sum} \\ & = - p_i \cdot p_j \end{aligned}
∂aj∂pi=∑⋅∑0⋅∑−eai⋅eaj=−pi⋅pj
4. Cross-Entropy + Softmax反向传播求导
Cross-Entropy的导数为:
H
′
(
y
i
,
p
i
)
=
−
∑
i
N
y
i
1
p
i
H'(y_i, p_i) = -\sum_i^Ny_i\frac{1}{p_i}
H′(yi,pi)=−i∑Nyipi1
根据链式法则(Chain Rule),整体损失对于
a
j
a_j
aj的导数为:
∂
H
∂
a
j
=
∂
H
∂
p
i
⋅
∂
p
i
∂
a
j
=
(
−
∑
i
y
i
1
p
i
)
⋅
∂
p
i
∂
a
j
−
−
−
①
\frac {\partial H}{\partial a_j} = \frac {\partial H}{\partial p_i} \cdot \frac {\partial p_i}{\partial a_j} = (-\sum_iy_i\frac{1}{p_i}) \cdot \frac {\partial p_i}{\partial a_j}---①
∂aj∂H=∂pi∂H⋅∂aj∂pi=(−i∑yipi1)⋅∂aj∂pi−−−①
当
i
=
j
i=j
i=j:
①
=
−
∑
i
=
j
y
i
1
p
i
⋅
p
i
⋅
(
1
−
p
j
)
=
−
∑
i
=
j
y
i
⋅
(
1
−
p
j
)
=
−
y
i
+
y
i
p
j
(
因
为
只
有
i
,
可
以
把
∑
去
掉
)
−
−
−
②
\begin{aligned} ① & = -\sum_{i=j}y_i\frac{1}{p_i} \cdot p_i \cdot (1 - p_j) \\ & = -\sum_{i=j}y_i\cdot (1 - p_j) \\ & = -y_i + y_ip_j (因为只有i,可以把\sum去掉)---② \end{aligned}
①=−i=j∑yipi1⋅pi⋅(1−pj)=−i=j∑yi⋅(1−pj)=−yi+yipj(因为只有i,可以把∑去掉)−−−②
当
i
≠
j
i \neq j
i=j:
①
=
−
∑
i
≠
j
y
i
1
p
i
⋅
(
−
p
i
⋅
p
j
)
=
∑
i
≠
j
y
i
p
j
−
−
−
③
\begin{aligned} ① & = -\sum_{i \neq j}y_i\frac{1}{p_i} \cdot (-p_i \cdot p_j) \\ & = \sum_{i \neq j}y_i p_j --- ③ \end{aligned}
①=−i=j∑yipi1⋅(−pi⋅pj)=i=j∑yipj−−−③
因为②和③其实是①的互斥情况,所以可以合并:
①
=
②
+
③
(
记
住
在
②
中
,
i
=
j
)
=
−
y
i
+
y
i
p
j
+
∑
i
≠
j
y
i
p
j
=
−
y
i
+
(
∑
i
=
j
y
i
p
j
+
∑
i
≠
j
y
i
p
j
)
=
−
y
i
+
∑
i
N
y
i
p
j
(
因
为
y
i
是
o
n
e
−
h
o
t
,
∑
i
N
y
i
=
1
)
=
p
j
−
y
j
(
因
为
②
中
i
=
j
,
则
y
i
=
y
j
)
\begin{aligned} ① & = ②+③(记住在②中,i=j) \\ & = -y_i + y_ip_j + \sum_{i \neq j}y_i p_j \\ & = -y_i + (\sum_{i=j}y_ip_j + \sum_{i \neq j}y_i p_j) \\ & = -y_i + \sum_i^N y_ip_j(因为y_i是one-hot,\sum_i^N y_i=1) \\ & = p_j - y_j(因为②中i=j,则y_i=y_j) \end{aligned}
①=②+③(记住在②中,i=j)=−yi+yipj+i=j∑yipj=−yi+(i=j∑yipj+i=j∑yipj)=−yi+i∑Nyipj(因为yi是one−hot,i∑Nyi=1)=pj−yj(因为②中i=j,则yi=yj)
整个Softmax+CrossEntropy的求导推导下来发现,
H
H
H对于
a
j
a_j
aj的梯度值,就是让他的
p
j
p_j
pj去减对应的label值(
y
j
y_j
yj)。
举例P = [0.5, 0.3, 0.2],Y=[1, 0, 0],对应的导数就是 [-0.5, 0.3, 0.2]。
5. 代码验证
先看下x在softmax+cross entorpy前向计算并且BP后,所产生的梯度是多少,即 ∂ H ∂ a j \frac{\partial H}{\partial a_j} ∂aj∂H,在这个例子中分别对a1,a2,a3求梯度:
x = torch.randn((1, 3), requires_grad=True)
# tensor([[-0.3876, 0.2697, -1.6527]], requires_grad=True)
y = torch.randint(3, (1,), dtype=torch.int64)
# tensor([1])
loss = F.cross_entropy(x, y)
# F.cross_entropy含了softmax+cross_entropy
# 因此直接调用即可,无需先使用F.softmax
print(loss)
# tensor(0.5095, grad_fn=<NllLossBackward>)
loss.backward()
print(x.grad)
# tensor([[ 0.3113, -0.3992, 0.0879]])
下面再看下 p i p_i pi 的值:
F.softmax(x, dim=1)
# tensor([[0.3113, 0.6008, 0.0879]], grad_fn=<SoftmaxBackward>)
发现没有!发现没有!除了 a 1 . g r a d a_1.grad a1.grad 比 p 1 p_1 p1 减了1之外,其他都没变!正正验证了上面的公式推导!
6. 总结
总结一下,在多分类问题中,softmax+cross entropy是比较普遍,且计算速度较快的损失函数(loss function),因为它的梯度仅仅只用把概率值(pi)减去标签(yi)即可!
这在训练的初期,可以提供较快的训练速度,以提供后续优化的方向。当然,后续也包括对损失函数的优化!