Gumbel softmax trick pytorch(快速理解附代码)

也可查看我的知乎: Gumbel softmax trick (快速理解附代码)

(一)目的

在深度学习中,对某一个离散随机变量 X X X进行采样,并且又要保证采样过程是可导的(因为要用梯度下降进行优化,并且用BP进行权重更新),那么就可以用Gumbel softmax trick。属于重参数技巧(re-parameterization)的一种。

首先我们要介绍,什么是Gumbel distribution,然后再介绍怎么用到梯度下降中,最后用pytorch实现它。

(二)什么是Gumbel distribution?

一个简单的例子:

女子高中三年级一共有16个班,从每个班抽30人,那么现在总共有16组 30人的样本

如果看每组样本里面的身高分布,大概率是服从正态分布的。现在,从每组样本里面挑出身高最高的人,将这些人再组成一个新的样本集合,也就是,新的样本集合有16个人。

那么你会发现,这16个人的样本集合就是服从的Gumbel 分布,且是极大值 Gumbel distribution。当然,如果换成抽最矮的人,这个分布就是极小值 Gumbel distribution。

定义

一种极值分布(或者叫做Fisher-Tippett extreme value distributions),顾名思义就是用来研究极值(极大值,或者极小值)的一种概率分布形式。和别的一些分布形式一样,给定一个描述分布的公式,然后再给定公式中的某些参数,那么就确定了这个分布。

(本质就是想用数学语言或公式来逼近或解释现实世界观察到的现象,比如自然界很多现象可以用正太分布来描述,自然地,也存在一些自然现象,要用极值分布来描述。)

下面定义极大值的Gumbel distribution。

CDF:

F ( x ; μ , β ) = e − e − ( x − μ ) β F(x;\mu,\beta)=e^{-e^{- \frac {(x-\mu)}{\beta}}} F(x;μ,β)=eeβ(xμ)

PDF:

f ( x ) = ∂ F ∂ x = 1 β e − ( z + e − z ) , f(x)=\frac{\partial F}{\partial x}=\frac{1}{\beta}e^{-(z+e^{-z})}, f(x)=xF=β1e(z+ez),
where z = x − μ β z=\frac{x-\mu}{\beta} z=βxμ.

标准Gumbel 分布:

即, μ = 0 , β = 1 \mu=0, \beta=1 μ=0,β=1, 则CDF为:
F ( x ; μ = 0 , β = 1 ) = F ( x ) = e − e − ( x ) F(x;\mu=0,\beta=1)=F(x)=e^{-e^{-(x)}} F(x;μ=0,β=1)=F(x)=ee(x)

函数图像:

在这里插入图片描述

(三)什么是Gumbel softmax trick?

Gumbel分布描述了自然界或者说人造的某种数据(其实也是自然界吧,毕竟人也是自然的一部分。)的极值分布的 “规律”(分布其实只是认识”规律“的一种方式)。所以自然地,我们之所以会用到Gumbel分布,就是因为我们要处理的数据中,存在极值分布(~废话)。

考虑如下场景:

对一个离散随机变量 X \mathbf{X} X进行采样,随机变量的取值范围为 { 1 , 2 , . . . , K } \{1,2,...,K\} {1,2,...,K}。首先要知道随机变量的分布函数,这里假设用MLP学习一个K维的向量: h ∈ R K \mathbf{h} \in \mathbb{R}^K hRK

(假如是直接做inference的话,不考虑概率意义,那么我们直接取这个向量元素最大值的下标当做预测的离散变量值就可以了,即, X i = arg ⁡ max ⁡ i h i X_i = \arg\max_i h_i Xi=argmaxihi.,但我们希望的是预测的离散变量具有概率意义,或者说得到的多个预测值的经验分布符合理论的概率分布。否则的话就是deterministic的,会导致某些小概率的变量值根本取不到,进而影响后续的任务。)

所以,我们需要赋予概率意义。通常,我们可以用softmax函数作用到 h \mathbf{h} h求得一个符合概率意义的新概率向量,即:
p i = s o f t m a x ( h , h i ) = e x p ( h i ) ∑ i e x p ( h i ) . p_i=softmax(h,h_i)=\frac{exp(h_i)}{\sum_i exp(h_i)}. pi=softmax(h,hi)=iexp(hi)exp(hi).
这样我们就获得了各个离散取值的概率分布 p ∈ [ 0 , 1 ] K \mathbf{p} \in [0,1]^K p[0,1]K,其中 p i = P r { X i = i } p_i=Pr\{X_i=i\} pi=Pr{Xi=i}。这里 p \mathbf{p} p是一个在K维simplex中的一个向量。

到这里,我们得到了 X X X的概率分布,如果要直接得到离散变量,直接取 X i = arg ⁡ max ⁡ i p i X_i = \arg\max_i p_i Xi=argmaxipi即可。(注意,这里每次inference的时候,取了最大值,是不是和Gumbel分布的含义很像了。)

问题是我们需要的是采样,也就是生成的多个样本的频率分布要符合其理论的概率分布。另外,可以开始考虑,是否能够将求导采样这两个操作解耦。

如果知道一些reparameterization trick的技巧,很容易想到,我们只需要将 p \mathbf{p} p加上一个要学习的参数无关(即无需进行求导)的某个随机变量 g \mathbf{g} g,那么采样过程就可以通过 g \mathbf{g} g进行(曲线救国了算是),这样做相当于把求导采样解耦了。这里,只需要保证结合后的分布,与原分布 p \mathbf{p} p相等或近似即可。

接下来就是与服从Gumbel分布的随机变量 g \mathbf{g} g结合:

X i = arg ⁡ max ⁡ i ( log ⁡ ( p i ) + g i ) X_i = \arg \max_i (\log(p_i) + g_i) Xi=argimax(log(pi)+gi)
这里 g i g_i gi是一个提前采样好的标准Gumbel分布序列。通过这种方法,理论上可以证明,这个新随机变量的分布函数和原分布函数相等。证明见:。。

但这样的问题在于 arg ⁡ max ⁡ ( ) \arg\max () argmax()不可导,导致无法使用梯度下降来更新参数。所以一种办法是将随机变量的取值从 1 , . . . , K {1,...,K} 1,...,K变为用一个K维的one_hot向量编码来表示。比如,本来取 X i = i X_i=i Xi=i,现如果用one_hot来表示的话,就是 X i = ( 0 , . . . , 1 , . . . , 0 ) X_i = (0,...,1,...,0) Xi=(0,...,1,...,0),也就是第 i i i个下标的值为1,其它都为0,我们记第 i i i个下标的值为 y i y_i yi。那么我们就可以用softmax函数来近似这个one_hot向量:

y i = exp ⁡ ( ( log ⁡ ( p i ) + g i ) / τ ) ∑ k = 1 K exp ⁡ ( ( log ⁡ ( p k ) + g k ) / τ ) y_i = \frac{\exp((\log(p_i) + g_i)/\tau)}{\sum_{k=1}^K\exp((\log(p_k) + g_k)/\tau)} yi=k=1Kexp((log(pk)+gk)/τ)exp((log(pi)+gi)/τ)
这里的 τ \tau τ被叫做温度系数,或者说是一个缩放因子。一般来说, τ < 1 \tau < 1 τ<1,想象一下以 e e e为底的指数分布图像,可以发现,如果 τ \tau τ越小,指数的值 e ( x / τ ) 越大,简记 x = l o g ( p i ) + g i e^{(x/\tau)}越大,简记x=log(p_i)+g_i e(x/τ)越大,简记x=log(pi)+gi。也就是说,这个 τ \tau τ存在的意义就是让本来大的 x x x越大,所以会导致 y i y_i yi越接近1,并且 ∀ j ≠ i , y j \forall j \neq i, y_j j=i,yj会接近0,所以 X i X_i Xi就更接近一个one_hot表示。
图片来源于文章[7]:图片来源于文章[7]

(四)如何生成Gumbel分布的样本

最后一步,就是如何生成Gumbel 分布的样本,即,如何产生 g i g_i gi

这里使用最常见的一种方法也就是inverse CDF method。先求出Gumbel的CDF函数 F ( x ; μ , β ) F(x;\mu,\beta) F(x;μ,β)的反函数 x = F − 1 ( y ; μ , β ) = μ − β ln ⁡ ( − ln ⁡ y ) x = F^{-1}(y;\mu,\beta)=\mu - \beta \ln(- \ln y) x=F1(y;μ,β)=μβln(lny)(根据CDF的公式: y = F ( x ; μ , β ) y=F(x;\mu,\beta) y=F(x;μ,β),把y和x反过来表示就可),然后只要生成 y ∼ U n i f o r m ( 0 , 1 ) y \sim Uniform(0,1) yUniform(0,1)的均匀分布的序列,那么相应的 x x x就服从Gumbel分布, x ∼ G u m b e l ( μ , β ) x \sim Gumbel(\mu, \beta) xGumbel(μ,β),也即, x x x的CDF函数为原来的 F ( x ) F(x) F(x)证明如下:
P ( F − 1 ( y ) ≤ x ) = P ( y ≤ F ( x ) ) = ∫ 0 F ( x ) p d f ( y ) d y = ∫ 0 F ( x ) 1 d y = F ( x ) P(F^{-1}(y) \leq x)= P(y \leq F(x))=\int_0^{F(x)}pdf(y)dy=\int_0^{F(x)}1dy=F(x) P(F1(y)x)=P(yF(x))=0F(x)pdf(y)dy=0F(x)1dy=F(x)


到这里我们就可以通过以上的公式进行采样了。

(五)pytorch实现

下面用pytorch实现一下上面描述的采样过程。

# Gumbel softmax trick:

import torch
import torch.nn.functional as F
import numpy as np

def inverse_gumbel_cdf(y, mu, beta):
    return mu - beta * np.log(-np.log(y))

def gumbel_softmax_sampling(h, mu=0, beta=1, tau=0.1):
    """
    h : (N x K) tensor. Assume we need to sample a NxK tensor, each row is an independent r.v.
    """
    shape_h = h.shape
    p = F.softmax(h, dim=1)
    y = torch.rand(shape_h) + 1e-25  # ensure all y is positive.
    g = inverse_gumbel_cdf(y, mu, beta)
    x = torch.log(p) + g  # samples follow Gumbel distribution.
    # using softmax to generate one_hot vector:
    x = x/tau
    x = F.softmax(x, dim=1)  # now, the x approximates a one_hot vector.
    return x

N = 10  # 假设 有N个独立的离散变量需要采样
K = 3   # 假设 每个离散变量有3个取值
h = torch.rand((N, K))  # 假设 h是由一个神经网络输出的tensor。

mu = 0
beta = 1
tau = 0.1

samples = gumbel_softmax_sampling(h, mu, beta, tau)

References

  1. https://mathworld.wolfram.com/GumbelDistribution.html
  2. https://www.itl.nist.gov/div898/handbook/eda/section3/eda366g.htm
  3. https://en.wikipedia.org/wiki/Fisher%E2%80%93Tippett%E2%80%93Gnedenko_theorem
  4. https://en.wikipedia.org/wiki/Gumbel_distribution
  5. https://www.cnblogs.com/initial-h/p/9468974.html
  6. https://arxiv.org/pdf/1611.04051.pdf
  7. https://arxiv.org/abs/1611.01144
### PyTorch `nonzero` 函数不可微的原因 在 PyTorch 中,`torch.nonzero` 返回的是张量中非零元素的位置索引。由于这些位置是离散的选择结果而非连续数值输出,因此无法定义梯度流经此操作的方式[^1]。 对于需要保持可导的操作场景来说,直接应用 `nonzero` 可能会中断反向传播链路,从而阻碍模型训练过程中参数更新的有效性。 ### 替代方案 为了实现既能够获取满足条件的元素又不破坏计算图连贯性的目的,可以考虑如下几种方式: #### 使用 Softmax 和 Log-Sum-Exp 巧妙组合 通过构建一个基于 softmax 的掩码来近似指示哪些位置应该被选中,这种方法可以在一定程度上保留求导路径的同时达到筛选的效果。 ```python import torch from torch import nn def soft_nonzero(x, temperature=0.1): mask = (x != 0).float() exp_x = torch.exp(mask / temperature) sum_exp = torch.sum(exp_x, dim=-1, keepdim=True) return exp_x / sum_exp ``` 这里引入了一个温度超参数用于控制逼近程度,在实际部署时可根据具体需求调整该值以平衡准确率与平滑度之间的关系。 #### 利用 Smooth Approximation 技术 另一种思路是对原始逻辑表达式做光滑化处理,比如采用 Sigmoid 或 Tanh 来代替硬阈值判断,这样做的好处是可以让整个过程变得处处可导。 ```python def smooth_nonzero(x, k=50.0): return torch.sigmoid(k * x) ``` 上述代码片段展示了利用 sigmoid 函数作为激活门限的例子,其中系数 \(k\) 控制着曲线陡峭的程度,较大数值可以使输出更接近于二元判定行为。 #### 基于 Gumbel-Max Trick 实现不同iable Sampling 当面对多分类或多标签选择问题时,Gumbel-max trick 提供了一种优雅的方法来进行概率抽样并维持可微特性。 ```python class DifferentiableNonZero(nn.Module): def __init__(self, tau=1., hard=False): super().__init__() self.tau = tau self.hard = hard def forward(self, logits): gumbels = -(-((-(logits + 1e-20)).exp())).log() # Sample from Gumbel(0, 1) noisy_logits = (logits.log_softmax(dim=-1) + gumbels) / self.tau y_soft = noisy_logits.softmax(dim=-1) if not self.training or not self.hard: return y_soft index = y_soft.max(-1, keepdim=True)[1] y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.) ret = y_hard - y_soft.detach() + y_soft return ret ``` 这段实现了带有软采样的机制,允许在网络推理阶段获得确切的一热编码表示形式,而在训练期间则遵循松弛后的分布规律变化。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值