gumbel-softmax如何实现离散分布可微+torch代码+原理+证明


在这里插入图片描述

背景

这里举一个简单的情况,当前我们有p1, p2, p3三个概率,我们需要得到最优的一个即max(p1, p2, p3),例如当前p3 = max(p1, p2, p3),那么理想输出应当为[0, 0, 1],然后应用于下游的优化目标,这种场景在搜索等场景经常出现。
如果暴力的进行clip或者mask操作转化为独热向量的话会导致在梯度反向传播的时候无法更新上游网络。因为p1和p2对应的梯度一定为0。

方法通俗理解

针对上述情况,采用重参数化的思路可以解决。
即然每次前向传播理想情况下是0-1独热向量向量,但同时能保证[p1, p2, p3]这个分布能被根据概率被更新。于是采用了一种重参数化的方法,即从每次都从一个分布中采样一个u,这个u属于一个均匀分布,从这个均匀分布通过转换变成[p1, p2, p3]这个分布。这样就能即保证梯度可以反向传播,同时根据每次采样来实现按照[p1, p2, p3]这个分布更新,而不是每次只能更新最大的一个。
而这种方法就是重参数化。

什么是重参数化

Reparameterization,重参数化,这是一个方法论,是一种技巧。
我们首先可以抽象出来它的数学表达形式:
L θ = E z ~ p θ ′ ( z ) ( f θ ( z ) ) \begin{equation} L_{\theta} = E_{z~p_{\theta'}(z)}(f_{\theta}(z)) \end{equation} Lθ=Ezpθ(z)(fθ(z))
注意:在有些时候 θ ′ ∈ θ \theta' \in \theta θθ或者 θ ′ = θ \theta' = \theta θ=θ
如何理解:这里我们的优化目标是 L θ L_{\theta} Lθ,其中 f θ ( ) f_{\theta}() fθ()一般是我们的模型,而计算 z z z是从分布 p θ ′ ( z ) p_{\theta'}(z) pθ(z)中采样得到的。但是问题是我们不能把一个分布输入到 f θ ( ) f_{\theta}() fθ()中去,只能从选择一个特定的 z z z,但是这样就没法更新 θ ′ \theta' θ
综上,重参数化就是从给定分布中采样得到一个 z z z,同时保证了梯度可以更新 θ ′ \theta' θ,这种保证采样分布和给定分布无损转换的采样策略叫做重参数化。(个人理解,欢迎大佬指正)

由于我们现在解决的是gumbel-softmax问题,所以只关注当 p θ ′ ( z ) p_{\theta'}(z) pθ(z)是离散的情况下,此时:
L θ = E z ~ p θ ′ ( z ) ( f θ ( z ) ) = ∑ p θ ′ ( z ) ( f θ ( z ) ) \begin{equation} L_{\theta} = E_{z~p_{\theta'}(z)}(f_{\theta}(z)) = \sum p_{\theta'}(z)(f_{\theta}(z)) \end{equation} Lθ=Ezpθ(z)(fθ</

### Gumbel-Softmax算子定义 Gumbel-Softmax算子是一种用于离散随机变量分布参数化的技术,允许通过梯度下降优化这些变量。该方法引入了一个可分的松弛版本来近似离散采样过程,在保持样本接近原始离散空间的同时使得反向传播成为可能[^1]。 具体来说,对于一个具有类别概率 \( \pi_1,\ldots ,\pi_k \) 的多项式分布而言,可以通过添加来自Gumbel(0, 1)分布的噪声并应用softmax函数来进行重参数化操作: \[ z_i=\frac{\exp((g_i+\log(\pi _i))/\tau )}{\sum_{j}\exp ((g_j+\log (\pi _j))/\tau)} \] 其中\( g_i=-\log(-\log(u)), u∼U(0,1)\),而温度超参τ控制着输出分布锐利程度;当τ趋近于零时,得到的结果更倾向于硬性分配给某个特定类别的one-hot编码形式;反之则趋向于均匀分布。 ### 实现代码示例 下面是一个简单的Python实现例子,展示了如何利用PyTorch库中的功能构建Gumbel Softmax层: ```python import torch from torch import nn class GumbelSoftmax(nn.Module): def __init__(self, tau=1.0, hard=False): super(GumbelSoftmax, self).__init__() self.tau = tau self.hard = hard def sample_gumbel(self, shape, eps=1e-20): U = torch.rand(shape).cuda() return -torch.log(-torch.log(U + eps) + eps) def forward(self, logits): y = logits + self.sample_gumbel(logits.size()) y = F.softmax(y / self.tau, dim=-1) if not self.training or not self.hard: return y shape = y.size() _, ind = y.max(dim=-1) y_hard = torch.zeros_like(y).view(-1, shape[-1]) y_hard.scatter_(1, ind.view(-1, 1), 1) y_hard = y_hard.view(*shape) return (y_hard - y).detach() + y ``` 此模块接收未标准化的日志几率作为输入,并返回经过Gumbel Softmax变换后的张量。如果设置`hard=True`,那么在训练期间会采用直通估计器(straight-through estimator)的方式获得离散的选择结果,而在评估模式下总是给出软性的加权平均表示。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值