KL散度
KL散度(Kullback–Leibler divergence,又称相对熵)用于衡量两个概率分布之间的差异。注意它不具有对称性,一般 $KL(P\,\|\,Q)\neq KL(Q\,\|\,P)$。
$$KL(P\,\|\,Q)=\int_{-\infty}^{+\infty}p(x)\log\frac{p(x)}{q(x)}\,dx \tag{1}$$
其中 $P$、$Q$ 为随机变量 $X$ 的两个概率分布,$p$、$q$ 为对应的概率密度函数。
如果P,Q均为高斯分布,即:
$$
P=\mathcal{N}(\mu_1,\sigma_1^2),\qquad Q=\mathcal{N}(\mu_2,\sigma_2^2) \tag{2}
$$
那么(1)可以化简为:
$$
KL\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)=\log\frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2} \tag{3}
$$
用matplotlib库绘制两个高斯分布的KL散度变化动画:
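红线(上图,位置移动的分布 $P$):
- $\mu$ 在 $[-10, 10)$ 之间以 0.1 为步长变化
- $\sigma$ 始终为 1
蓝线(上图,固定的参考分布 $Q$):
- $\mu$ 始终为 0
- $\sigma$ 始终为 1

下图紫线为 $KL(P\,\|\,Q)$ 随 $\mu$ 的变化。代入上面的化简公式($\sigma_1=\sigma_2=1,\ \mu_2=0$)可得 $KL=\mu^2/2$,即一条抛物线。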
代码如下:
import functools
from typing import List
from matplotlib.lines import Line2D
from matplotlib.patches import ConnectionPatch
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
def _gauss_pdf(x, mu=0.0, sigma=1.0):
    """Return the density of N(mu, sigma**2) evaluated at *x* (scalar or array)."""
    return np.exp(-((x - mu) ** 2) / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))


def _gauss_kl(mu1, sigma1, mu2, sigma2):
    """Return KL(N(mu1, sigma1**2) || N(mu2, sigma2**2)) in closed form.

    *mu1* may be a scalar or an array; the result broadcasts accordingly.
    """
    return (
        np.log(sigma2 / sigma1)
        + (sigma1**2 + (mu1 - mu2) ** 2) / (2 * sigma2**2)
        - 0.5
    )


def update(
    mean,
    pdf_line: Line2D,
    kl_div_line: Line2D,
    connection_line: ConnectionPatch,
    points: List[Line2D],
    x_min: float,
    x_max: float,
    step: float,
    std: float = 1.0,
    ref_mean: float = 0.0,
    ref_std: float = 1.0,
):
    """Redraw one animation frame for the moving Gaussian N(mean, std**2).

    Args:
        mean: Mean of the moving distribution for this frame (the value the
            animation passes in as the first positional argument).
        pdf_line: Line showing the moving distribution's PDF (top panel).
        kl_div_line: Line tracing the KL divergence for means from ``x_min``
            up to ``mean`` (bottom panel).
        connection_line: Connector whose ``xy1``/``xy2`` endpoints are moved
            to the PDF peak and to the current KL value.
        points: Two single-point markers; ``points[0]`` is placed at the PDF
            peak and ``points[1]`` at the current KL value.
        x_min, x_max: Range over which the PDF is sampled.
        step: Sampling resolution of the x grid.
        std: Standard deviation of the moving distribution (default 1).
        ref_mean, ref_std: Mean and standard deviation of the fixed reference
            distribution (defaults 0 and 1, i.e. the standard normal).

    Returns:
        Tuple of the artists modified by this call (usable with ``blit=True``).
    """
    # Sample grid over [x_min, x_max]. linspace pins both endpoints exactly,
    # whereas np.arange with a float step can emit a spurious extra sample.
    num = max(int(round((x_max - x_min) / step)) + 1, 2)
    x1 = np.linspace(x_min, x_max, num)
    y1 = _gauss_pdf(x1, mean, std)

    # KL curve traced so far: grid points left of ``mean`` followed by ``mean``
    # itself, so the curve ends exactly at the highlighted marker.
    # (np.arange(x_min, mean + step, step) can overshoot by one step through
    # float rounding, e.g. at mean = -9.7.)
    x2 = np.append(x1[x1 < mean - step / 2], mean)
    y2 = _gauss_kl(x2, std, ref_mean, ref_std)

    pdf_line.set_data(x1, y1)
    kl_div_line.set_data(x2, y2)

    # Endpoints of the dotted connector: PDF peak (top) and current KL (bottom).
    peak = float(_gauss_pdf(mean, mean, std))
    kl_now = float(_gauss_kl(mean, std, ref_mean, ref_std))
    connection_line.xy1 = (mean, peak)
    connection_line.xy2 = (mean, kl_now)
    points[0].set_data([mean], [peak])
    points[1].set_data([mean], [kl_now])

    return (pdf_line, kl_div_line, connection_line, *points)
if __name__ == "__main__":
    # Two stacked panels sharing the x axis: the top one shows the PDFs, the
    # bottom one shows the KL divergence as a function of the moving mean.
    fig, ax = plt.subplots(2, 1, sharex=True, figsize=(8, 6.6))
    fig.suptitle("Animation of two Gauss Distribution's KL divergence", x=0.50, y=0.92)
    x_min, x_max, step = -10, 10, 0.10
    # Include x_max so the reference curve spans the whole visible x range.
    x_ords = np.linspace(x_min, x_max, 201)

    ax[0].set_xlim(x_min, x_max)
    ax[0].set_ylim(-0.2, 0.6)
    ax[0].set_ylabel("Probability Density")
    ax[1].set_xlim(x_min, x_max)
    ax[1].set_ylim(-5, 100)
    ax[1].set_xlabel("Mean")
    ax[1].set_ylabel("KL Divergence")

    # Fixed reference distribution: the standard normal N(0, 1).
    ax[0].plot(
        x_ords,
        1 / (np.sqrt(2 * np.pi)) * np.exp(-(x_ords**2) / 2),
        label="mean=0 & std=1",
    )

    # Artists that the animation updates on every frame.
    (pdf_line,) = ax[0].plot([], [], color="red", label="mean=[-10, 10] & std=1")
    (kl_div_line,) = ax[1].plot([], [], color="purple")
    (point1,) = ax[0].plot([x_min], [0], color="cyan", marker="o")
    (point2,) = ax[1].plot([x_min], [0], color="cyan", marker="o")
    ax[0].legend()
    ax[0].grid()
    ax[1].grid()

    # Dotted connector between the PDF peak (top) and the KL value (bottom),
    # spanning both axes, so it is added at the figure level.
    connection = ConnectionPatch(
        [x_min, 0],
        [x_min, 0],
        "data",
        "data",
        axesA=ax[0],
        axesB=ax[1],
        ls="dotted",
        lw=2,
        color="pink",
    )
    fig.add_artist(connection)

    # Build the animation: one frame per mean value in [x_min, x_max).
    animation = FuncAnimation(
        fig,
        func=functools.partial(
            update,
            pdf_line=pdf_line,
            kl_div_line=kl_div_line,
            connection_line=connection,
            points=[point1, point2],
            x_min=x_min,
            x_max=x_max,
            step=step,
        ),
        frames=np.arange(x_min, x_max, step),
        interval=50,
    )
    # Save before showing so the GIF is written without waiting for the user
    # to close the window; saving after the window is closed can fail on some
    # interactive backends. Pillow is requested explicitly so no external
    # ffmpeg binary is needed for GIF output.
    animation.save("KL.gif", writer="pillow")
    plt.show()