KL Divergence (Relative Entropy)
1. Definition
We start with the formula:
$$
KL(p(x)\,\|\,q(x)) = \int p(x)\log\frac{p(x)}{q(x)}\,dx = \int p(x)\log p(x)\,dx - \int p(x)\log q(x)\,dx \tag{1}
$$
The second term in (1), $-\int p(x)\log q(x)\,dx$, is the cross-entropy.
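The quantity in (1) is always $\ge 0$. Its first term, the negative entropy of $p$, does not depend on $q$, so it is a constant when $p(x)$ is the fixed data distribution and $q(x)$ is the model. Using the cross-entropy as the cost function therefore minimizes the KL divergence. The minimum is reached when $KL=0$, that is, $q(x)=p(x)$, at which point the learned model is optimal.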
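The decomposition in (1) can be checked numerically for discrete distributions, where the integrals become sums. The following is a minimal sketch, not part of the original derivation: the distributions `p` and `q` are made-up example values, and the helper names `entropy`, `cross_entropy`, and `kl_divergence` are hypothetical.

```python
import numpy as np

def entropy(p):
    """Shannon entropy H(p) = -sum p log p (natural log)."""
    return -np.sum(p * np.log(p))

def cross_entropy(p, q):
    """Cross-entropy H(p, q) = -sum p log q, the second term of (1)."""
    return -np.sum(p * np.log(q))

def kl_divergence(p, q):
    """KL(p || q) = sum p log(p / q)."""
    return np.sum(p * np.log(p / q))

# Hypothetical example distributions with strictly positive entries.
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.4, 0.4, 0.2])

kl = kl_divergence(p, q)
decomposed = cross_entropy(p, q) - entropy(p)

print(kl, decomposed)        # the two values should agree up to rounding
print(kl_divergence(p, p))   # KL of a distribution with itself
```

Since `entropy(p)` does not change with `q`, any `q` that lowers `cross_entropy(p, q)` lowers `kl_divergence(p, q)` by the same amount.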
3. Cross-Entropy Loss Avoids Vanishing Gradients
Given a data matrix $X_{N\times p}$, suppose we use the squared loss. Writing $Y$ for the model output and $\hat{Y}$ for the target labels, the model is:
$$
\begin{aligned}
S_{N\times 1} &= X_{N\times p}W_{p\times 1}\\
Y_{N\times 1} &= \sigma(S_{N\times 1})\\
Loss &= \frac{1}{2}\|Y-\hat{Y}\|_2^2
\end{aligned}
$$
By the chain rule:
$$
\nabla_W L = \frac{\partial L}{\partial Y}\frac{\partial Y}{\partial S}\frac{\partial S}{\partial W} = X^T\cdot\big[(\sigma(S)-\hat{Y})\odot\sigma(S)\odot(1-\sigma(S))\big]
$$
Because $\sigma(x)\to 1$ as $x\to+\infty$ and $\sigma(x)\to 0$ as $x\to-\infty$, the factor $\sigma(S)\odot(1-\sigma(S))$ tends to $0$ whenever $|S|$ is large. The $\sigma$ function therefore causes vanishing gradients.
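The saturation effect can be made concrete with a single training example. The sketch below assumes the squared-loss gradient takes the form $X^T[(Y-\hat{Y})\odot Y\odot(1-Y)]$ with $Y=\sigma(XW)$. The helper names `sigmoid` and `mse_grad` and the one-sample data are hypothetical, chosen only for illustration.

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def mse_grad(X, W, Y_hat):
    """Gradient of 0.5 * ||sigmoid(X W) - Y_hat||^2 with respect to W."""
    Y = sigmoid(X @ W)
    return X.T @ ((Y - Y_hat) * Y * (1.0 - Y))

# One hypothetical sample with a single feature and target label 1.
X = np.array([[1.0]])
Y_hat = np.array([[1.0]])

# As the weight becomes more negative, the prediction gets more wrong,
# yet the gradient shrinks because sigmoid(s) * (1 - sigmoid(s)) -> 0.
for w in [0.0, -2.0, -5.0, -10.0]:
    W = np.array([[w]])
    y = sigmoid(X @ W)[0, 0]
    g = mse_grad(X, W, Y_hat)[0, 0]
    print(f"w={w:6.1f}  prediction={y:.5f}  gradient={g:.2e}")
```

The rows with large negative `w` are the cases where the model is most wrong, and they receive the smallest updates under the squared loss.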
If we use the cross-entropy loss instead, the model is:
$$
\begin{aligned}
S_{N\times 1} &= X_{N\times p}W_{p\times 1}\\
Y_{N\times 1} &= \sigma(S_{N\times 1})\\
Loss &= -\sum_{i=1}^N \big[\hat{y}_i\log y_i+(1-\hat{y}_i)\log(1-y_i)\big]
\end{aligned}
$$
By the chain rule:
$$
\frac{\partial L}{\partial Y} = -\begin{bmatrix}\vdots\\ \frac{\hat{y}_i}{y_i}-\frac{1-\hat{y}_i}{1-y_i}\\ \vdots\end{bmatrix} = \begin{bmatrix}\vdots\\ \frac{y_i-\hat{y}_i}{y_i(1-y_i)}\\ \vdots\end{bmatrix},\qquad
\frac{\partial Y}{\partial S} = \begin{bmatrix}\vdots\\ y_i(1-y_i)\\ \vdots\end{bmatrix}
$$
Therefore
$$
\nabla_W L = X^T\cdot\begin{bmatrix}\vdots\\ y_i-\hat{y}_i\\ \vdots\end{bmatrix} = X^T\cdot(Y-\hat{Y})
$$
The factor $\sigma(S)\odot(1-\sigma(S))$ cancels, which removes the vanishing gradient caused by saturation of the output $\sigma$.
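A finite-difference check is one way to confirm a closed-form gradient for the cross-entropy loss. The sketch below assumes the conventional sign, with a leading minus in front of the sum, so that the gradient is $X^T(Y-\hat{Y})$, where $Y=\sigma(XW)$ is the prediction and $\hat{Y}$ holds the labels. The helper names, the random data, and the step size `eps` are all hypothetical choices.

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def ce_loss(X, W, Y_hat):
    """Binary cross-entropy with a leading minus sign; Y_hat holds the labels."""
    Y = sigmoid(X @ W)
    return -np.sum(Y_hat * np.log(Y) + (1.0 - Y_hat) * np.log(1.0 - Y))

def ce_grad(X, W, Y_hat):
    """Closed-form gradient X^T (Y - Y_hat)."""
    return X.T @ (sigmoid(X @ W) - Y_hat)

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))                            # N = 8, p = 3
W = rng.normal(size=(3, 1))
Y_hat = rng.integers(0, 2, size=(8, 1)).astype(float)  # random 0/1 labels

# Central finite differences, one weight at a time.
eps = 1e-6
numeric = np.zeros_like(W)
for j in range(W.shape[0]):
    step = np.zeros_like(W)
    step[j, 0] = eps
    numeric[j, 0] = (ce_loss(X, W + step, Y_hat)
                     - ce_loss(X, W - step, Y_hat)) / (2 * eps)

# Largest disagreement between numeric and closed-form gradients.
print(np.max(np.abs(numeric - ce_grad(X, W, Y_hat))))

# A saturated, wrong prediction (label 1, pre-activation -10) still
# produces a gradient whose magnitude is close to 1.
saturated = ce_grad(np.array([[1.0]]), np.array([[-10.0]]), np.array([[1.0]]))
print(saturated)
```

The last line is the case where a squared loss would yield a near-zero update; with cross-entropy the gradient magnitude tracks the prediction error directly.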
2. Proof of Non-Negativity
Here $\log$ denotes the natural logarithm. It is known that
$$
\ln x\le x-1 \quad \text{for } x\in(0,+\infty) \tag{2}
$$
with equality if and only if $x=1$. Therefore
$$
\begin{aligned}
-KL(p(x)\,\|\,q(x)) &= \int p(x)\log\frac{q(x)}{p(x)}\,dx\\
&\le \int p(x)\left(\frac{q(x)}{p(x)}-1\right)dx\\
&= \int \big(q(x)-p(x)\big)\,dx\\
&= \int q(x)\,dx-\int p(x)\,dx\\
&= 1-1=0
\end{aligned}
\tag{3}
$$
Therefore
$$
\begin{aligned}
&KL(p(x)\,\|\,q(x))\ge 0\\
&KL(p(x)\,\|\,q(x))=0 \iff p(x)=q(x)
\end{aligned}
\tag{4}
$$
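The inequality $\ln x\le x-1$ and the resulting bound $KL\ge 0$ can be sanity-checked numerically. This is only an empirical illustration under assumptions, not a substitute for the proof: the grid of `x` values, the number of trials, and the Dirichlet sampling of random discrete distributions are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)

def kl_divergence(p, q):
    """Discrete KL(p || q) = sum p log(p / q), natural log."""
    return np.sum(p * np.log(p / q))

# Gap (x - 1) - ln x on a grid of positive x; it should never be
# meaningfully negative, and it is smallest near x = 1.
x = np.linspace(0.01, 10.0, 1000)
gap = (x - 1.0) - np.log(x)
print(gap.min(), x[np.argmin(gap)])

# KL divergence between random pairs of 5-point distributions.
values = []
for _ in range(1000):
    p = rng.dirichlet(np.full(5, 2.0))
    q = rng.dirichlet(np.full(5, 2.0))
    values.append(kl_divergence(p, q))
print(min(values))
```

Every sampled pair should give a non-negative value, consistent with (4).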