带你认识神奇的Gumbel trick

The Gumbel soft-max

Gumbel trick有两个用途,一个用途是是用来对离散分布进行采样,这是一种重参数化(reparameterization trick)的技巧,另外一个用途是用于估计normalizing partition function,也就是分布的归一化项。本文将介绍这两种方法的原理。

下面是一个使用gumbel trick来模拟离散分布采样的例子:
这里写图片描述

这里写图片描述

如上图例子,首先有 log ⁡ α 1 \log \alpha_1 logα1,这可以看做是一个多项式分布的概率值, 然后加上一个gumbel noise G1,最后取最大值,就是我们要的离散分布的样本,样本出现的概率跟 α \alpha α成正比。这个过程可以形式化为,设X是离散随机分布 P ( X = k ) ∝ α k P(X=k)\propto \alpha_k P(X=k)αk , 设 { G k } k ≤ K \{G_k\}_{k\le K} {Gk}kK是独立同分布的Gumbel 分布的随机变量。于是:
X = arg ⁡ max ⁡ k ( log ⁡ ( α k ) + G k ) X=\arg\max_k(\log(\alpha_k)+G_k) X=argkmax(log(αk)+Gk)

为了让这个argmax可求导,于是就把中间的argmax换成softmax。
我们从这个图底下的“+”号可以看到,这是一种重参数的方法,通过加一个随机的,固定分布的噪声,从而实现采样。这个噪声的采样方法可以通过Inverse transform sampling方法直接从均匀分布进行采样,即
G i ∼ − l o g ( − l o g ( Uniform ( 0 , 1 ) ) ) G_i\sim -log(-log(\text{Uniform}(0,1))) Gilog(log(Uniform(0,1)))

目前一篇论文“Categorical Reparametrization with Gumbel-Softmax ”正是用了这个方法去对离散的隐状态进行采样,从而使得里面的参数可导。

Gumbel distribution

要想知道他为什么有这样的效果,我们需要先介绍一下gumbel distribution

这一个分布,可以把看作是一个关于“最大值”的概率的分布,比如你想预测明年河流最大的水位是多少,那么你就可以用gumbel分布去预测,这个分布会告诉你每一个值作为“最大值“的概率是多少。一个很简单的推广,如果你对这个分布取个负号的话,你就可以去预测最小值。
这里写图片描述

他的概率密度函数:
f ( x ) = 1 β e − ( z + e − z ) f(x) = {\frac {1}{\beta }}e^{-(z+e^{-z})} f(x)=β1e(z+ez)
其中 z = x − μ β z=\frac{x-\mu}{\beta} z=βxμ

他的分布函数:
F ( x ) = e − e − ( x − μ ) / β F(x)= e^{-e^{-(x-\mu )/\beta }} F(x)=ee(xμ)/β
均值: E ( X ) = μ + c β E(X)=\mu+c\beta E(X)=μ+cβ,方差:$ {\frac {\pi ^{2}}{6}}\beta ^{2} , 其 中 ,其中 ,c$是一个常数( Euler–Mascheroni constant )

Gumbel trick用于估计归一化项

我们先考虑一下,求解normalizing partition function. 就是分布的归一化项的问题。

定义一个非标准化的mass function p ~ : X → [ 0 , ∞ ) \tilde{p} : \mathcal{X} \to [0, \infty) p~:X[0,) 这个分布是没有标准化的,也就是他加起来不等于1.而它的标准化项normalizing partition function为 Z : = ∑ x ∈ X p ~ ( x ) Z:= \sum_{x \in \mathcal{X}} \tilde{p}(x) Z:=xXp~(x),接来下我们定义 ϕ ( x ) = ln ⁡ p ~ ( x ) \phi(x)=\ln \tilde{p}(x) ϕ(x)=lnp~(x) 对其概率密度取对数。

于是可以证明:
max ⁡ x ∈ X { ϕ ( x ) + γ ( x ) } ∼ Gumbel ( − c + ln ⁡ Z ) \max_{x \in \mathcal{X}} \{ \phi(x) + \gamma(x) \} \sim \text{Gumbel}(-c + \ln Z) xXmax{ϕ(x)+γ(x)}Gumbel(c+lnZ)
其中 γ ∼ Gumbel ( − c ) \gamma \sim \text{Gumbel}(-c) γGumbel(c)。这就意味,只要我们从 max ⁡ x ∈ X { ϕ ( x ) + γ ( x ) } \max_{x \in \mathcal{X}} \{ \phi(x) + \gamma(x) \} maxxX{ϕ(x)+γ(x)}中采集足够多的样本,我们就能够知道Z的取值(通过求期望得到)。

具体的推导过程如下:
T = max ⁡ x ∈ X { ϕ ( x ) + γ ( x ) } T=\max_{x\in \mathcal{X}} \{\phi (x)+\gamma (x)\} T=maxxX{ϕ(x)+γ(x)},于是他的概率分布等于
P ( T < t ) = P ( max ⁡ x ∈ X { ϕ ( x ) + γ ( x ) } < t ) = ∏ x ∈ X P ( ϕ ( x ) + γ ( x ) < t ) ( 最 大 值 小 于 t 等 价 于 每 一 项 都 小 于 t ) = ∏ x ∈ X P ( γ ( x ) < t − ϕ ( x ) ) = ∏ x ∈ X F G u m b e l ( t − ϕ ( x ) ) = exp ⁡ ( − ∑ x ∈ X exp ⁡ ( − ( t − ϕ ( x ) + c ) ) ) = exp ⁡ ( − Z exp ⁡ ( − ( t + c ) ) ) = exp ⁡ ( − exp ⁡ ( − ( t + c − ln ⁡ Z ) ) ) ⇒ F ( t )  where  t ∼ Gumbel ( − c + ln ⁡ Z ) \begin{aligned} P(T< t) & =P(\max_{x\in \mathcal{X}} \{\phi (x)+\gamma (x)\}< t)\\ & =\prod _{x\in \mathcal{X}} P(\phi (x)+\gamma (x)< t)( 最大值小于t等价于每一项都小于t)\\ & =\prod _{x\in \mathcal{X}} P(\gamma (x)< t-\phi (x))\\ & =\prod _{x\in \mathcal{X}} F_{Gumbel} (t-\phi (x))\\ & =\exp\left( -\sum _{x\in \mathcal{X}}\exp( -(t-\phi (x)+c))\right)\\ & =\exp( -Z\exp( -(t+c)))\\ & =\exp( -\exp( -(t+c-\ln Z)))\\ & \Rightarrow F(t)\text{ where } t\sim \text{Gumbel} (-c+\ln Z) \end{aligned} P(T<t)=P(xXmax{ϕ(x)+γ(x)}<t)=xXP(ϕ(x)+γ(x)<t)(tt)=xXP(γ(x)<tϕ(x))=xXFGumbel(tϕ(x))=exp(xXexp((tϕ(x)+c)))=exp(Zexp((t+c)))=exp(exp((t+clnZ)))F(t) where tGumbel(c+lnZ)
我们发现这个max的函数,最后是服从 Gumbel ( − c + ln ⁡ Z ) \text{Gumbel} (-c+\ln Z) Gumbel(c+lnZ)分布的,也就是说,我们只要求这个分布的期望: E = − c + ln ⁡ Z + c = ln ⁡ Z E=-c+\ln Z+c=\ln Z E=c+lnZ+c=lnZ就可以把 ln ⁡ Z \ln Z lnZ还原出来!这个例子也从侧面说明了Gumbel分布用于表示最大值的概率分布的优势所在(这里优势我觉得直观来看体现在,因为max就相当于所有小于某个数的概率连乘,而因为gumbel分布的指数项的性质,所以,连乘之后指数项没有消失,从而还是服从gumbel分布)。

为什么Gumbel trick能够模拟多项式分布采样?

如果我们的p是已经标准化的p,那么Z=1,于是,这个分布只与 γ ( x ) \gamma(x) γ(x)有关。实际上,当 γ ∼ G u m b e l ( 0 , 1 ) \gamma \sim Gumbel(0,1) γGumbel(0,1),而p是多项式分布的时候就是我们模拟多项式分布进行采样时所服从的分布!那么为什么这个Gumbel 分布能够模拟多项式分布?

我们来考虑一个问题,对于公式1,多项式一共有K个类别。那么第k个类别恰好是最大的概率是多少?

z k = log ⁡ α k + G k \displaystyle z_{k} =\log \alpha _{k} +G_{k} zk=logαk+Gk要求解这个问题,我们要先求出 z k z_k zk是最大的概率多少?然后再对z积分,从而求出第k个是最大的概率。
Pr ⁡ ( log ⁡ α k + G k > max ⁡ i ≠ k log ⁡ α i + G i ) = Pr ⁡ ( max ⁡ i ≠ k log ⁡ α i + G i < log ⁡ α k + G k ) = ∏ i ≠ k Pr ⁡ ( log ⁡ α i + G i < log ⁡ α k + G k ) = ∏ i ≠ k Pr ⁡ ( G i < log ⁡ α k + G k − log ⁡ α i ) = ∏ i ≠ k F ( log ⁡ α k + G k − log ⁡ α i ) = ∏ i ≠ k exp ⁡ { − exp ⁡ { − ( log ⁡ α k + G k − log ⁡ α i ) } } \begin{aligned} \Pr (\log \alpha _{k} +G_{k} >\max_{i\neq k}\log \alpha _{i} +G_{i} ) & =\Pr (\max_{i\neq k}\log \alpha _{i} +G_{i} < \log \alpha _{k} +G_{k} )\\ & =\prod _{i\neq k}\Pr (\log \alpha _{i} +G_{i} < \log \alpha _{k} +G_{k} )\\ & =\prod _{i\neq k}\Pr (G_{i} < \log \alpha _{k} +G_{k} -\log \alpha _{i} )\\ & =\prod _{i\neq k} F(\log \alpha _{k} +G_{k} -\log \alpha _{i})\\ & =\prod _{i\neq k}\exp\{-\exp\{-(\log \alpha _{k} +G_{k} -\log \alpha _{i})\}\} \end{aligned} Pr(logαk+Gk>i=kmaxlogαi+Gi)=Pr(i=kmaxlogαi+Gi<logαk+Gk)=i=kPr(logαi+Gi<logαk+Gk)=i=kPr(Gi<logαk+Gklogαi)=i=kF(logαk+Gklogαi)=i=kexp{exp{(logαk+Gklogαi)}}

现在我们有了 z k \displaystyle z_{k} zk是最大的那个概率值,现在我们想知道第k个元素是最大的概率值是多少,因此,我们需要对所有z的取值进行积分,从而得到第k个位置取值最大的概率。
Pr ⁡ ( k is largest  ∣   { x k ′ } ) = ∫ exp ⁡ { − ( z k − log ⁡ α k ) − exp ⁡ { − ( z k − log ⁡ α k ) } } ∏ i ≠ k exp ⁡ { − exp ⁡ { − ( z k − log ⁡ α i ) } }   d z k = ∫ exp ⁡ { − z k + log ⁡ α k − exp ⁡ { − z k } ∑ i = 1 K exp ⁡ { log ⁡ α i } }   d z k = exp ⁡ { log ⁡ α k } ∑ i = 1 K exp ⁡ { log ⁡ α i } \begin{aligned} \Pr (\text{k is largest} \ |\ \{x_{k'} \}) & =\int \exp \{-(z_{k} -\log \alpha _{k} )-\exp \{-(z_{k} -\log \alpha _{k} )\}\} \prod _{i\neq k}\exp \{-\exp \{-(z_{k} -\log \alpha _{i} )\}\}\ \mathrm{d} z_{k}\\ & =\int \exp \{-z_{k} +\log \alpha _{k} -\exp \{-z_{k} \}\sum ^{K}_{i=1}\exp \{\log \alpha _{i} \}\}\ \mathrm{d} z_{k}\\ & =\frac{\exp \{\log \alpha _{k} \}}{\sum ^{K}_{i=1}\exp \{\log \alpha _{i} \}} \end{aligned} Pr(k is largest  {xk})=exp{(zklogαk)exp{(zklogαk)}}i=kexp{exp{(zklogαi)}} dzk=exp{zk+logαkexp{zk}i=1Kexp{logαi}} dzk=i=1Kexp{logαi}exp{logαk}

这时候,奇迹来了,上面这条等式恰好是一个softmax的公式,也就是说,第k个位置最大的概率,恰好就是对离散概率分布的一个近似。而且一个有趣的性质是这里的 α k \alpha_k αk是不需要归一化的,因为经过softmax之后他就自动归一化了!

参考资料

https://en.wikipedia.org/wiki/Gumbel_distribution
https://irenechen.net/2017/08/the-gumbel-trick/
https://www.youtube.com/watch?v=wVkLM2KKHp8
https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/

  • 8
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值