两个高斯分布的KL散度&绘制动画

KL散度

KL散度用于衡量两个变量分布之间的差异性

K L ( P   ∣ ∣   Q ) = ∫ − ∞ + ∞ p ( x ) log ⁡ p ( x ) q ( x ) d x (1) KL(P\ ||\ Q)=\int_{-\infty}^{+\infty}p(x)\log\frac{p(x)}{q(x)}dx\tag{1} KL(P ∣∣ Q)=+p(x)logq(x)p(x)dx(1)

P、Q为随机变量X的两个概率分布;p、q为对应的概率密度函数

如果P,Q均为高斯分布,即:
P = N ( μ 1 , σ 1 2 ) Q = N ( μ 2 , σ 2 2 ) (2) P=\mathcal{N}(\mu_1,\sigma^2_1)\\ Q=\mathcal{N}(\mu_2,\sigma^2_2)\tag{2} P=N(μ1,σ12)Q=N(μ2,σ22)(2)
那么(1)可以化简为:
K L ( N ( μ 1 , σ 1 2 )   ∣ ∣   N ( μ 2 , σ 2 2 ) ) = log ⁡ σ 2 σ 1 + σ 1 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 − 1 2 KL(\mathcal{N}(\mu_1,\sigma^2_1)\ ||\ \mathcal{N}(\mu_2,\sigma^2_2))=\log\frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2} KL(N(μ1,σ12) ∣∣ N(μ2,σ22))=logσ1σ2+2σ22σ12+(μ1μ2)221
用matplotlib库绘制两个高斯分布的KL散度变化动画:

KL

红线:

  • μ \mu μ在(-10,10)之间变化
  • σ \sigma σ​始终为1

蓝线:

  • μ \mu μ始终为0
  • σ \sigma σ始终为1

代码如下:

import functools
from typing import List
from matplotlib.lines import Line2D
from matplotlib.patches import ConnectionPatch
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation


# 更新函数,用于在动画中更新正态分布的均值
def update(
    mean,
    pdf_line: Line2D,
    kl_div_line: Line2D,
    connection_line: ConnectionPatch,
    points: List[Line2D],
    x_min: float,
    x_max: float,
    step: float,
):
    # 生成两个正态分布的概率密度函数
    x1 = np.arange(x_min, x_max + step, step)
    y1 = 1 / (np.sqrt(2 * np.pi)) * np.exp(-((x1 - mean) ** 2) / 2)
    x2 = np.arange(x_min, mean + step, step)
    y2 = (1 + x2**2) / 2 - 0.5
    # 更新线的数据
    pdf_line.set_data(x1, y1)
    kl_div_line.set_data(x2, y2)
    connection_line.xy1 = (mean, 1 / (np.sqrt(2 * np.pi)))
    connection_line.xy2 = (mean, (1 + mean**2) / 2 - 0.5)
    points[0].set_data(np.expand_dims(connection_line.xy1, axis=-1))
    points[1].set_data(np.expand_dims(connection_line.xy2, axis=-1))


if __name__ == "__main__":
    # 创建图形和坐标轴
    fig, ax = plt.subplots(2, 1, sharex=True, figsize=(8, 6.6))
    fig.suptitle("Animation of two Gauss Distribution's KL divergence", x=0.50, y=0.92)

    x_min, x_max, step = -10, 10, 0.10
    x_ords = np.linspace(x_min, x_max, 200, endpoint=False)
    ax[0].set_xlim(x_min, x_max)
    ax[0].set_ylim(-0.2, 0.6)
    ax[0].set_ylabel("Probability Density")
    ax[1].set_xlim(x_min, x_max)
    ax[1].set_ylim(-5, 100)
    ax[1].set_xlabel("Mean")
    ax[1].set_ylabel("KL Divergence")
    # 绘制标准正态分布
    ax[0].plot(
        x_ords,
        1 / (np.sqrt(2 * np.pi)) * np.exp(-(x_ords**2) / 2),
        label="mean=0 & std=1",
    )
    # 初始化动画时要绘制的线
    (pdf_line,) = ax[0].plot([], [], color="red", label="mean=[-10, 10] & std=1")
    (kl_div_line,) = ax[1].plot([], [], color="purple")
    (point1,) = ax[0].plot([-10], [0], color="cyan", marker="o")
    (point2,) = ax[1].plot([-10], [0], color="cyan", marker="o")
    ax[0].legend(), ax[0].grid(), ax[1].grid()
    connection = ConnectionPatch(
        [-10, 0],
        [-10, 0],
        "data",
        "data",
        axesA=ax[0],
        axesB=ax[1],
        ls="dotted",
        lw=2,
        color="pink",
    )
    fig.add_artist(connection)
    # 创建动画
    animation = FuncAnimation(
        fig,
        func=functools.partial(
            update,
            pdf_line=pdf_line,
            kl_div_line=kl_div_line,
            connection_line=connection,
            points=[point1, point2],
            x_min=x_min,
            x_max=x_max,
            step=step,
        ),
        frames=np.arange(x_min, x_max, step),
        interval=50,
    )
    plt.show()
    animation.save("KL.gif")
  • 16
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值