参考自博客
KL 散度的三种估计方法
大家不难发现R1论文中的KL散度和平常的KL散度不一样,实际上它是KL散度的无偏低方差估计,我们今天来回顾下KL散度的几种估计方法。
在概率论和统计学中,Kullback-Leibler (KL) 散度是用来衡量两个概率分布之间差异的重要工具。在机器学习和数据科学的背景下,准确地估计 KL 散度是常见的挑战之一。今天,我想与大家分享三种用于估计 KL 散度的方法,这三种方法在实际使用中都各有优势。
K L [ q , p ] = ∑ x q ( x ) log q ( x ) p ( x ) = E x ∼ q [ log q ( x ) p ( x ) ] KL[q, p] = \sum_x q(x) \log \frac{q(x)}{p(x)} = E_{x \sim q}\left[\log \frac{q(x)}{p(x)}\right] KL[q,p]=x∑q(x)logp(x)q(x)=Ex∼q[logp(x)q(x)]
1. 简单估计(k1)
首先,我们来看最简单和直观的估计方法,直接将采样结果当估计结果,因为p和q位置反了,所以前面有一个负号。对于给定的分布 ( p ) 和 ( q ),其基本的计算公式为:
k 1 = − log ( p ( x ) q ( x ) ) k1 = -\log \left( \frac{p(x)}{q(x)} \right) k1=−log(q(x)p(x))
这个方法的优点是它是无偏的,即在多次取样后,其平均值会趋近于真实的 KL 散度值。然而,这个方法的缺点在于其高方差。在某些情况下,估计值为负,而我们知道KL始终为正。
2. 低方差估计(k2)
为了克服 k1 带来的高方差问题,我们可以使用第二种方法,称为 k2,其计算方式为:
k 2 = 1 2 ( log ( p ( x ) q ( x ) ) ) 2 k2 = \frac{1}{2}\left(\log \left( \frac{p(x)}{q(x)} \right)\right)^2 k2=21(log(q(x)p(x)))2
k2 方法的好处在于它的方差较低,并且始终为正。这是因为它通过平方对比了两个分布的差异,避免了负值对估计的影响。在实验数据中,k2 的偏差也相对较小,通常在0.2%以下,但在真实分布的 KL 散度较大时,偏差可能会增大。
3. 优化估计(k3)
最后,我们介绍第三种方法 k3。该方法通过控制变量降低了方差,并同时保持了无偏性,注意到r-1的期望为0。它的形式为:
k 3 = ( r − 1 ) − log ( r ) k3 = (r - 1) - \log(r) k3=(r−1)−log(r)
而 r = p ( x ) q ( x ) r = \frac{p(x)}{q(x)} r=q(x)p(x)这个估计方法的特点是:通过使用 Bregman 散度的思想,它为每个点的估计提供了额外的信息,从而提高了表现。在大多数情况下,k3 的方差不仅低于 k2,同时它也是无偏的,这使得它在多个场景下成为优选的估计器。
总结
在比较这三种估计方法时,我们发现:
- k1 是简单直接的,但高方差使得它在实践中常常不够稳定。
- k2 虽然降低了方差,但在真实数据中偏差可能会加大。
- k3 则在保持无偏性的同时,进一步减少了方差,展现出卓越的效果。
根据针对 ( p = N(0,1) ) 和 ( q = N(0.1,1) ) 的仿真结果,真实 KL 散度值约为 0.005。对比各个估计器在此情况下的偏差和标准差,不难发现 k3 凭借其出色的性能,似乎成为了最好的选择。因此GRPO实际上使用k3进行的KL散度的估计。