gumbel-softmax如何实现离散分布可微+torch代码+原理+证明


在这里插入图片描述

背景

这里举一个简单的情况,当前我们有p1, p2, p3三个概率,我们需要得到最优的一个即max(p1, p2, p3),例如当前p3 = max(p1, p2, p3),那么理想输出应当为[0, 0, 1],然后应用于下游的优化目标,这种场景在搜索等场景经常出现。
如果暴力的进行clip或者mask操作转化为独热向量的话会导致在梯度反向传播的时候无法更新上游网络。因为p1和p2对应的梯度一定为0。

方法通俗理解

针对上述情况,采用重参数化的思路可以解决。
即然每次前向传播理想情况下是0-1独热向量向量,但同时能保证[p1, p2, p3]这个分布能被根据概率被更新。于是采用了一种重参数化的方法,即从每次都从一个分布中采样一个u,这个u属于一个均匀分布,从这个均匀分布通过转换变成[p1, p2, p3]这个分布。这样就能即保证梯度可以反向传播,同时根据每次采样来实现按照[p1, p2, p3]这个分布更新,而不是每次只能更新最大的一个。
而这种方法就是重参数化。

什么是重参数化

Reparameterization,重参数化,这是一个方法论,是一种技巧。
我们首先可以抽象出来它的数学表达形式:
L θ = E z ~ p θ ′ ( z ) ( f θ ( z ) ) \begin{equation} L_{\theta} = E_{z~p_{\theta'}(z)}(f_{\theta}(z)) \end{equation} Lθ=Ezpθ(z)(fθ(z))
注意:在有些时候 θ ′ ∈ θ \theta' \in \theta θθ或者 θ ′ = θ \theta' = \theta θ=θ
如何理解:这里我们的优化目标是 L θ L_{\theta} Lθ,其中 f θ ( ) f_{\theta}() fθ()一般是我们的模型,而计算 z z z是从分布 p θ ′ ( z ) p_{\theta'}(z) pθ(z)中采样得到的。但是问题是我们不能把一个分布输入到 f θ ( ) f_{\theta}() fθ()中去,只能从选择一个特定的 z z z,但是这样就没法更新 θ ′ \theta' θ
综上,重参数化就是从给定分布中采样得到一个 z z z,同时保证了梯度可以更新 θ ′ \theta' θ,这种保证采样分布和给定分布无损转换的采样策略叫做重参数化。(个人理解,欢迎大佬指正)

由于我们现在解决的是gumbel-softmax问题,所以只关注当 p θ ′ ( z ) p_{\theta'}(z) pθ(z)是离散的情况下,此时:
L θ = E z ~ p θ ′ ( z ) ( f θ ( z ) ) = ∑ p θ ′ ( z ) ( f θ ( z ) ) \begin{equation} L_{\theta} = E_{z~p_{\theta'}(z)}(f_{\theta}(z)) = \sum p_{\theta'}(z)(f_{\theta}(z)) \end{equation} Lθ=Ezpθ(z)(fθ(z))=pθ(z)(fθ(z))
这也就是gumbel-softmax要解决的数学形式。

gumbel-softmax

gumbel-softmax给出的采样方案,叫做gumbel max:
从原来的 a r g m a x i ( [ p 1 , p 2 , . . . ] ) argmax_i([p1, p2, ...]) argmaxi([p1,p2,...]) a r g m a x i ( l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) ) , ϵ i ∈ U [ 0 , 1 ] argmax_i(log(p_i)-log(-log(\epsilon_i))), \epsilon_i \in U[0, 1] argmaxi(log(pi)log(log(ϵi))),ϵiU[0,1]

也就是先算出各个概率的对数 l o g ( p i ) log(p_i) log(pi),然后从均匀分布 U U U中采样随机数 ϵ i \epsilon_i ϵi,把 − l o g ( − l o g ϵ i ) −log(−log\epsilon_i) log(logϵi)加到 l o g ( p i ) log(p_i) log(pi),然后再进行后续操作。
这里可以理解为通过 ϵ \epsilon ϵ的采样将随机性增加。有的人会疑问,为什么格式变得这么复杂,各种算log,这是为什么?这个就涉及到下一节了,具体原因就是来保证数学的变换正确性,即我增加了随机性,但是保证分布的期望仍然是和原始[p1, p2, p3]是一致的,这个证明在下一节,是有比较严谨的数学证明的。

但是这里还有一个问题,就是argmax或者说onehot操作仍然会丢失梯度,所以采用带超参 τ \tau τ的softmax,来进行平滑:
s o f t m a x ( ( l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) ) / τ ) \begin{equation} softmax((log(p_i)-log(-log(\epsilon_i)))/\tau) \end{equation} softmax((log(pi)log(log(ϵi)))/τ)
其中 τ \tau τ也被称为退火参数,用来调整平滑的程度: τ \tau τ越小,越接近onhot向量。

这里也解释清楚了所谓gumbel-softmax是通过gumbel max实现重参数化,通过带退火参数的softmax实现梯度反向传递。

为什么是gumbel

这就涉及到一个gumbel max的证明了。
目标是证明针对 l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) log(p_i)-log(-log(\epsilon_i)) log(pi)log(log(ϵi)),当 a r g m a x i ( l o g ( p i ) − l o g ( − l o g ( ϵ i ) ) ) = 1 argmax_i(log(p_i)-log(-log(\epsilon_i))) = 1 argmaxi(log(pi)log(log(ϵi)))=1时,其概率为 p 1 p_1 p1

假设:
l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) log(p_1)-log(-log(\epsilon_1)) log(p1)log(log(ϵ1)) 最大

则:
l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) > l o g ( p 2 ) − l o g ( − l o g ( ϵ 2 ) ) log(p_1)-log(-log(\epsilon_1)) > log(p_2)-log(-log(\epsilon_2)) log(p1)log(log(ϵ1))>log(p2)log(log(ϵ2))
l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) > l o g ( p 3 ) − l o g ( − l o g ( ϵ 3 ) ) log(p_1)-log(-log(\epsilon_1)) > log(p_3)-log(-log(\epsilon_3)) log(p1)log(log(ϵ1))>log(p3)log(log(ϵ3))

l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) > l o g ( p 2 ) − l o g ( − l o g ( ϵ 2 ) ) log(p_1)-log(-log(\epsilon_1)) > log(p_2)-log(-log(\epsilon_2)) log(p1)log(log(ϵ1))>log(p2)log(log(ϵ2)) ->
ϵ 1 p 2 / p 1 > ϵ 2 \epsilon_1^{p_2/p_1} > \epsilon_2 ϵ1p2/p1>ϵ2
所以: p 1 p1 p1 大于 p 2 p_2 p2 的概率是 ϵ 1 p 2 / p 1 \epsilon_1^{p_2/p_1} ϵ1p2/p1

同理:
p 1 p1 p1 大于 p 3 p_3 p3 的概率是 ϵ 1 p 3 / p 1 \epsilon_1^{p_3/p_1} ϵ1p3/p1

所以 l o g ( p 1 ) − l o g ( − l o g ( ϵ 1 ) ) log(p_1)-log(-log(\epsilon_1)) log(p1)log(log(ϵ1)) 最大的概率是:
ϵ 1 p 2 / p 1 \epsilon_1^{p_2/p_1} ϵ1p2/p1 * ϵ 1 p 3 / p 1 \epsilon_1^{p_3/p_1} ϵ1p3/p1 * … = ϵ 1 ( 1 − p 1 ) / p 1 \epsilon_1^{(1-p_1)/p_1} ϵ1(1p1)/p1
E ( ϵ 1 ( 1 − p 1 ) / p 1 ) E(\epsilon_1^{(1-p_1)/p_1}) E(ϵ1(1p1)/p1) = ∫ 0 1 ϵ 1 ( 1 − p 1 ) / p 1 d ϵ \int_{0}^{1}\epsilon_1^{(1-p_1)/p_1} d\epsilon 01ϵ1(1p1)/p1dϵ = ∫ 0 1 ϵ 1 ( 1 / p 1 ) − 1 d ϵ 1 \int_{0}^{1}\epsilon_1^{(1/p_1)-1} d\epsilon_1 01ϵ1(1/p1)1dϵ1 = ( p 1 ( ϵ 1 1 / p 1 ) ) ∣ 0 1 (p_1(\epsilon_1^{1/p_1}))|^1_0 (p1(ϵ11/p1))01 = p 1 p_1 p1

证明假设成立

torch实现

def sample_gumbel(shape, eps=1e-20):
    U = torch.rand(shape)
    U = U.cuda()
    return -torch.log(-torch.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature=0.5):
    y = torch.log(logits) + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature=1, hard=False):
    """
    input: [B, n_class]
    return: [B, n_class] an one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)
    
    if not hard:
        return y

    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Set gradients w.r.t. y_hard gradients w.r.t. y
    y_hard = (y_hard - y).detach() + y
    return y_hard

思考

为什么gumbel-softmax和softmax的输出是不一样的?
为什么argmax(gumbel-softmax) 和 argmax(softmax)的结果也不一定一样?这是正常的吗?

大家共勉~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值