KL散度和JS散度学习笔记

  因为现在很多机器学习的算法中,都需要衡量两个概率分布的差异,特别是近年来比较火的生成模型领域。因此,我最近温习了一下KL散度和JS散度这两种衡量不同分布差异的方法,这里记录一下。注意,我们这里选取的都是离散型变量,如果要对连续型变量进行计算,就需要根据概率论知识把求和改成积分等等,原理是完全一样的。

KL散度

熵的定义

  首先,我们需要了解概率分布的熵的计算。假设随机变量 X X X服从某种概率分布,其有n种取值,则这个概率分布的熵可以这么计算:

H = − ∑ i = 1 n p ( x i ) log ⁡ p ( x i ) H=-\sum_{i = 1}^{n} p\left(x_{i}\right) \log p\left(x_{i}\right) H=i=1np(xi)logp(xi)

如果我们这里的对数是取2为底数,也就是 l o g 2 log_2 log2,那么根据信息论的相关知识,我们可以将熵理解为:我们编码信息所需要的最小比特数(二进制对应单位)。

KL散度的提出

  然后,假设我们现在面临一个问题,就是我们需要使用一个无法处理的概率分布,因此我们使用一个我们可以处理的概率分布来拟合这种分布,这个过程通常会造成一定的信息损失。而我们如何选择一个最适合的分布来最小化信息损失量呢?这就需要设计一种衡量两种概率分布差异的方法,也就是KL散度的提出。
  KL散度的计算方法如下:

D K L ( p ∥ q ) = E [ log ⁡ p ( x i ) − log ⁡ q ( x i ) ] = ∑ i = 1 n p ( x i ) ( l o g ( p ( x i ) ) − l o g ( q ( x i ) ) ) = ∑ i = 1 n p ( x i ) l o g ( p ( x i ) q ( x i ) ) \begin{aligned} D_{K L}(p \| q) & = E\left[\log p\left(x_{i}\right)-\log q\left(x_{i}\right)\right]\\ &= \sum_{i = 1}^{n}p(x_i)(log(p(x_i)) - log(q(x_i)))\\ &= \sum_{i = 1}^{n}p(x_i)log(\frac {p(x_i)}{q(x_i)}) \end{aligned} DKL(pq)=E[logp(xi)logq(xi)]=i=1np(xi)(log(p(xi))log(q(xi)))=i=1np(xi)log(q(xi)p(xi))

  根据这个公式,KL散度又被称为相对熵,根据刚刚对熵的物理意义的理解,我们这里可以把KL散度理解为:用q分布去编码p分布时所需要的额外信息量占用的比特数。
  有一个值得注意的地方是,KL散度是不对称的,也就是说一般情况下, D K L ( q ∥ p ) D_{KL}(q\|p) DKL(qp) D K L ( p ∥ q ) D_{KL}(p\|q) DKL(pq)是不相等的,这也导致我们不能将KL散度作为距离,因为距离的一个重要特征就是对称性。此外,KL散度还具有非负性,是通过Jensen不等式在凸积分中的命题证明的,这里不作详述,感兴趣的朋友可以自行搜索。

交叉熵

  交叉熵,同样是一个信息论中十分重要的概念,也是我们在做机器学习算法时经常使用的一种损失函数。它的定义如下:

H ( p , q ) = ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) H(p,q) = \sum_{i = 1}^{n}p(x_i)log(q(x_i)) H(p,q)=i=1np(xi)log(q(xi))

  它的物理意义是,用Q分布去编码服从P分布的变量X所需要的期望比特长度。我们再回去看一下KL散度的定义,可以发现:

D K L ( p ∣ ∣ q ) = ∑ i = 1 n p ( x i ) l o g ( p ( x i ) ) − ∑ i = 1 n p ( x i ) l o g ( q ( x i ) ) = − H ( p ( x i ) ) + H ( p , q ) = H ( p , q ) − H ( p ( x i ) ) \begin{aligned} D_{KL}(p||q) &= \sum_{i = 1}^{n}p(x_i)log(p(x_i))-\sum_{i = 1}^{n}p(x_i)log(q(x_i) )\\ &= - H(p(x_i)) + H(p,q)\\ &= H(p,q) - H(p(x_i)) \end{aligned} DKL(p∣∣q)=i=1np(xi)log(p(xi))i=1np(xi)log(q(xi))=H(p(xi))+H(p,q)=H(p,q)H(p(xi))

  也就是说,KL散度可以视为交叉熵减去原分布的信息熵,这也与它们的物理意义完美符合。在机器学习的算法中,为了方便,直接就使用交叉熵作为损失函数,用来最小化预测数据分布和训练集数据分布的差异。


JS散度

JS散度的提出

  前面提过,KL散度是不对称的,这个性质在应用时会面临一些问题,所以就有人提出了JS散度,不但解决了对称问题,而且JS散度的取值范围为[0,1],更利于衡量和判断。
  JS散度的计算也很简单,我们现在假设有 m = 1 2 ( p + q ) m = \frac{1}{2}(p + q) m=21(p+q),则JS散度计算公式为:

D J S ( p ∣ ∣ q ) = 1 2 D K L ( p ∣ ∣ m ) + 1 2 D K L ( q ∣ ∣ m ) D_{JS}(p||q) = \frac{1}{2}D_{KL}(p||m) + \frac{1}{2}D_{KL}(q||m) DJS(p∣∣q)=21DKL(p∣∣m)+21DKL(q∣∣m)
 
  我们代入KL散度的计算公式,再结合 ∑ i = 1 n p ( x i ) \sum_{i = 1}^{n}p(x_i) i=1np(xi) = ∑ i = 1 n q ( x i ) \sum_{i = 1}^{n}q(x_i) i=1nq(xi) = 1,不难得出:

D J S ( P ∥ Q ) = 1 2 ∑ i = 1 n p ( x ) log ⁡ ( p ( x ) p ( x ) + q ( x ) ) + 1 2 ∑ i = 1 n q ( x ) log ⁡ ( q ( x ) p ( x ) + q ( x ) ) + log ⁡ 2 D_{JS}(P \| Q)=\frac{1}{2} \sum_{i = 1}^{n} p(x) \log \left(\frac{p(x)}{p(x)+q(x)}\right)+\frac{1}{2} \sum_{i = 1}^{n} q(x) \log \left(\frac{q(x)}{p(x)+q(x)}\right)+\log 2 DJS(PQ)=21i=1np(x)log(p(x)+q(x)p(x))+21i=1nq(x)log(p(x)+q(x)q(x))+log2

JS散度的缺馅

  大家观察上面最后化简得到的等式,可以思考一下如果p分布和q分布完全没有重叠部分时会出现什么情况。没错,如果p分布和q分布完全不重叠,会导致JS散度为一个定值,就是log2。如果在训练机器学习模型时我们使用了JS散度,很可能在遇到这种情况的时候由于JS散度为常数导致梯度为零无法更新。
  下面我来简单推导一下,假设p和q分布的情况如图所示:
p分布和q分布 
  从图中不难发现,有一块区域是空的,也就是说,在这片区域p分布和q分布都不曾涉足,假设这里面有一个 X k X_k Xk的取值 X k X_k Xk,以它为分界线,我们可以将JS散度表达式分为两部分相加,即 ∑ i = 1 n = ∑ i = 1 k + ∑ i = k + 1 n \sum_{i = 1}^{n} = \sum_{i = 1}^{k} + \sum_{i = k + 1}^{n} i=1n=i=1k+i=k+1n。(公式太长我就用这个简写了)

  当 i < = k i <= k i<=k时, p ( x i ) p(x_i) p(xi) = 0,我们代入JS散度除去log2的部分,可以得出:

1 2 ∑ i = 1 k 0 × log ⁡ ( 0 0 + q ( x i ) ) + 1 2 ∑ i = 1 k q ( x i ) log ⁡ ( q ( x i ) 0 + q ( x i ) ) = 0 \frac{1}{2} \sum_{i = 1}^{k} 0 \times \log \left(\frac{0}{0+q(x_i)}\right)+\frac{1}{2} \sum_{i = 1}^{k} q(x_i) \log \left(\frac{q(x_i)}{0+q(x_i)}\right)=0 21i=1k0×log(0+q(xi)0)+21i=1kq(xi)log(0+q(xi)q(xi))=0

  当 i > k i > k i>k时, q ( x i ) q(x_i) q(xi) = 0,我们再代入这个式子,则有:

1 2 ∑ i = k + 1 n p ( x i ) log ⁡ ( p ( x i ) p ( x i ) + 0 ) + 1 2 ∑ i = k + 1 n 0 × log ⁡ ( 0 p ( x i ) + 0 ) = 0 \frac{1}{2} \sum_{i = k + 1}^{n} p(x_i) \log \left(\frac{p(x_i)}{p(x_i)+0}\right)+\frac{1}{2} \sum_{i = k + 1}^{n} 0 \times \log \left(\frac{0}{p(x_i)+0}\right)=0 21i=k+1np(xi)log(p(xi)+0p(xi))+21i=k+1n0×log(p(xi)+00)=0

  因此,我们把它们加起来,就能够得到 D J S ( p ∥ q ) = l o g 2 D_{JS}(p\|q) = log2 DJS(pq)=log2

  感兴趣的朋友可以看看我的这篇博客😚机器学习系列第一章,以后有机会就更新😁

  • 18
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值