GumbelSoftmax感性理解--可导式输出随机类别

GumbelSoftmax

本文不涉及GumbelSoftmax的具体证明和推导,有需要请参见1,只是从感性角度来直观讲解为何要引入GumbelSoftmax,同时又为什么不用Gumbelmax。

 GumbelSoftmax提出是为了应对分布采样不可导的问题。举例而言,我们从网络经Softmax层输出了类别概率向量 p 1 = [ 0.9 , 0.1 , 0.1 ] p_1=[0.9,0.1,0.1] p1=[0.9,0.1,0.1] p 2 = [ 0.5 , 0.2 , 0.3 ] p_2=[0.5,0.2,0.3] p2=[0.5,0.2,0.3],那么如果我们训练网络最终的输出需求只是从中得到对应的类别结果(分类任务),那么 p 1 p_1 p1 p 2 p_2 p2其实都是合理的,因为我们我们最终得到的都只会是 a r g m a x ( p ) = 0 argmax(p)=0 argmax(p)=0。但如果我们正在进行生成任务,这一类别结果只是一个中间值,而我们希望这一类别概率向量真正体现出了概率的含义,那么 p 1 , p 2 p_1,p_2 p1p2就会有着显著的差异,后者采样出第1、2类的的结果要明显高于前者。
 因此为了突出网络输出的概率属性,我们可以简单的依照这一概率向量进行采样即可,定一个均匀分布 U ( 0 , 1 ) U(0,1) U(0,1),落在哪个概率区间就认为输出哪一个类别,但这一采样操作是不可导的,也就无法使网络端到端训练。GumbelSoftmax的提出就是为了解决这一问题,它让网络输出类别随机的同时,又使得这一采样过程可导。一句话总结:GumbelSoftmax代替了网络中的 a r g m a x argmax argmax,引入了:

  1. 随机性:网络的输出真的变成了由最终概率向量决定的随机变量,即logit输出 [ 0.9 , 0.1 , 0.1 ] [0.9,0.1,0.1] [0.9,0.1,0.1]真的可能因抽样而判定为第2类;
  2. 可导性:这一抽样过程可导,可以融入到网络端到端训练过程中。(伪)

Gumbelmax

 为了让网络的输出类别真正的随机,我们需要先将对 a r g m a x argmax argmax进行替换,既然网络输出随机的就不可导的话,我们就利用重参数技巧将这一随机性放到另一个随机变量上,也就得到了Gumbelmax,公式如下:
x = a r g m a x ( l o g ( x ) + G ) , \bold{x}=argmax(log(\bold{x})+\bold{G}), x=argmax(log(x)+G),
其中 x , G \bold{x},\bold{G} x,G分别是网络输出的概率向量、符合Gumble分布的噪声向量, G i = − l o g ( − l o g ( U i ) ) , U i   U ( 0 , 1 ) G_i=-log(-log(U_i)),U_i~U(0,1) Gi=log(log(Ui)),Ui U(0,1)。这一噪声向量的引入就会使得argmax的输出结果发生扰动,变成一个随机变量。同样是之前的例子, l o g ( p 1 ) + G log(p_1)+\bold{G} log(p1)+G就有可能变为 [ 0.5 , 0.6 , 0.5 ] [0.5,0.6,0.5] [0.5,0.6,0.5]而使得最终输出类别为第1类,而 a r g m a x ( l o g ( x ) + G ) argmax(log(\bold{x})+\bold{G}) argmax(log(x)+G)服从这一随机变量服从 x x x的离散分布列证明见附1
 通过引入Gumbelmax,我们成功的为网络的类别输出引入了随机性。但可导性的问题并没有解决,因为这里仍然是存在了argmax。

GumbelSoftmax

 GumbelSoftmax对Gumbelmax的解决也很简单,它又把argmax替换成为了softmax,得到如下计算:
x = s o f t m a x ( ( l o g ( x ) + G ) / τ ) , \bold{x}=softmax((log(\bold{x})+\bold{G})/\tau), x=softmax((log(x)+G)/τ),
其中 τ \tau τ为为温度参数,这一算式中通过对argmax的软化实现了可导操作。至此,也就完成了为了网络输出引入可导随机性的目标。

矛盾

 讨论至此,有个非常反直觉的考量,那就是相比于Gumbelmax的硬输出onehot向量,GumbelSoftmax的输出似乎又变成了概率向量,我们想要得到的具体的类别输出,还要继续再取argmax也就是 a r g m a x ( s o f t m a x ( ( l o g ( x ) + G ) ) / τ ) argmax(softmax((log(\bold{x})+\bold{G}))/\tau) argmax(softmax((log(x)+G))/τ)。那么这不是仍然不可导,仍然返回了Gumbelmax的窘境?因此这里依据个人理解要做出以下的澄清:

  1. 确实不可导,如果我们希望从GumbelSoftmax输出一个类别值,那么就必然引入argmax,也就必然不可导。而在实际过程中,我们则是回避了对argmax求导的问题,把它当成一个identity的操作,直接对 s o f t m a x ( ( l o g ( x ) + G ) ) / τ softmax((log(\bold{x})+\bold{G}))/\tau softmax((log(x)+G))/τ进行求导,具体可以参见pytorch中GumbelSoftmax的实现2
  2. 既然如此,那为什么不照猫画虎在使用Gumbelmax的时候就忽略argmax的存在,直接对 ( l o g ( x ) + G ) (log(\bold{x})+\bold{G}) (log(x)+G)求导?这是因为 a r g m a x ( l o g ( x ) + G ) argmax(log(\bold{x})+\bold{G}) argmax(log(x)+G)本身才是我们想要求导的对象,而因为argmax本身不可导,所以引入了softmax来替代,也即我们相对 [ 1 , 0 , 0 ] [1,0,0] [1,0,0]求导,迫不得已对 [ 0.8 , 0.1 , 0.1 ] [0.8,0.1,0.1] [0.8,0.1,0.1]求导,算是某种程度上的导数近似。而在1中的argmax本身也不是我们求导的对象,只是由于这一近似带来的补偿。而更进一步的,假设我们直接对 ( l o g ( x ) + G ) (log(\bold{x})+\bold{G}) (log(x)+G)进行求导,那么这一近似带来的误差只会更大,也让随机噪声的引入失去了意义,等价于对 l o g ( x ) log(x) log(x)求导。这也就是为什么开头的可导加了,因为我们是在对softmax求导,而不是argmax。

更新

 近日发现其实有人关注过为什么不用Gumbelmax+argmax,而是使用GumbelSoftmax+argmax来输出one-hot向量的问题,而且解释的非常清楚,可以参见3

总结

 整体而言,GumbelSoftmax通过引入了Gumble随机噪声使得输出的类别真正具有随机性,而将argmax软化为softmax则使得这一随机过程可导。

参考文献


  1. Gumbel-Softmax Trick和Gumbel分布 ↩︎ ↩︎

  2. 请问用Gumbel-softmax的时候,怎么让softmax输出的概率分布转化成one-hot向量? ↩︎

  3. 重参数化技巧(Gumbel-Softmax)
    怎样克服神经网络训练中argmax的不可导性? ↩︎

  • 18
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值