
背景
这里举一个简单的情况,当前我们有p1, p2, p3三个概率,我们需要得到最优的一个即max(p1, p2, p3),例如当前p3 = max(p1, p2, p3),那么理想输出应当为[0, 0, 1],然后应用于下游的优化目标,这种场景在搜索等场景经常出现。
如果暴力的进行clip或者mask操作转化为独热向量的话会导致在梯度反向传播的时候无法更新上游网络。因为p1和p2对应的梯度一定为0。
方法通俗理解
针对上述情况,采用重参数化的思路可以解决。
即然每次前向传播理想情况下是0-1独热向量向量,但同时能保证[p1, p2, p3]这个分布能被根据概率被更新。于是采用了一种重参数化的方法,即从每次都从一个分布中采样一个u,这个u属于一个均匀分布,从这个均匀分布通过转换变成[p1, p2, p3]这个分布。这样就能即保证梯度可以反向传播,同时根据每次采样来实现按照[p1, p2, p3]这个分布更新,而不是每次只能更新最大的一个。
而这种方法就是重参数化。
什么是重参数化
Reparameterization,重参数化,这是一个方法论,是一种技巧。
我们首先可以抽象出来它的数学表达形式:
L θ = E z ~ p θ ′ ( z ) ( f θ ( z ) ) \begin{equation} L_{\theta} = E_{z~p_{\theta'}(z)}(f_{\theta}(z)) \end{equation} Lθ=Ez~pθ′(z)(fθ(z))
注意:在有些时候 θ ′ ∈ θ \theta' \in \theta θ′∈θ或者 θ ′ = θ \theta' = \theta θ′=θ。
如何理解:这里我们的优化目标是 L θ L_{\theta} Lθ,其中 f θ ( ) f_{\theta}() fθ()一般是我们的模型,而计算 z z z是从分布 p θ ′ ( z ) p_{\theta'}(z) pθ′(z)中采样得到的。但是问题是我们不能把一个分布输入到 f θ ( ) f_{\theta}() fθ()中去,只能从选择一个特定的 z z z,但是这样就没法更新 θ ′ \theta' θ′。
综上,重参数化就是从给定分布中采样得到一个 z z z,同时保证了梯度可以更新 θ ′ \theta' θ′,这种保证采样分布和给定分布无损转换的采样策略叫做重参数化。(个人理解,欢迎大佬指正)
由于我们现在解决的是gumbel-softmax问题,所以只关注当 p θ ′ ( z ) p_{\theta'}(z) pθ′(z)是离散的情况下,此时:
L θ = E z ~ p θ ′ ( z ) ( f θ ( z ) ) = ∑ p θ ′ ( z ) ( f θ ( z ) ) \begin{equation} L_{\theta} = E_{z~p_{\theta'}(z)}(f_{\theta}(z)) = \sum p_{\theta'}(z)(f_{\theta}(z)) \end{equation} Lθ=Ez~pθ′(z)(fθ</