全面详解gumbel softmax
Gumbel-Softmax有什么用 ?
据我所知,gumbel softmax允许模型中有从离散的分布(比如类别分布categorical distribution)中采样的这个过程变得可微,从而允许反向传播时可以用梯度更新模型参数,所以这让gumbel softmax在深度学习的很多领域都有应用,比如分类分割任务、采样生成类任务AIGC、强化学习、语音识别、NAS等等。如果你是主动搜索到这篇文章的,那你对gumbel softamx的应用应该有自己的理解,如果跟我一样,暂时没用到的,也可以先学起来,说不定以后的算法能用上。
2. 这个函数可导
基于前人们的知识成果积累,论文《Categorical Reparameterization with Gumbel-Softmax》的作者还真找到了解决方法,第一个问题的方法是使用Gumbel Max Trick,第二个问题的方法是把Gumbel Max Trick里的argmax换成softmax,综合起来就是Gumbel Softmax。
前置知识
累计分布函数
在介绍gumbel之前, 我们先看一下离散概率分布采样在计算机编程中是如何实现的。它的采样方法可以表示为:
从上图我们可以感受到,采样值在x=3附近比较多,密度比较高,所以相应的它的概率密度函数(PDF,Probability Density Function)在x=3处是最大的,如下图所示:
不同参数的gumbel分布的PDF函数曲线
写成代码的话,就是
重参数技巧(Re-parameterization Trick)
gumbel max trick里用到了重参数的思想,所以先介绍一下重参数技巧。
最原始的自编码器(AE,Auto Encoder,自编码器就是输入一张图片,编码成一个隐向量,再把这个隐向量重建回原图的样子)长这样:
左右两边是端到端的输入输出网络,中间的绿色是提取的特征向量,这是一种直接从图片提取特征并将特征直接重建回去的方式,很符合直觉。
而VAE(Variational Auto Encoder)长这样:
VAE的想法是不直接用编码器去提取特征向量(也就是隐向量),而是提取这张图像的分布特征,比如说均值和标准差,也就是把绿色的特征向量替换为分布的参数向量。然后需要解码图像的时候,就用编码器输出的分布参数采样得到特征向量样本,用这个样本去重建图像。
以上就是重参数技巧在图像生成领域的一个案例,可以表示为下图所示:
Gumbel-Max Trick
Gumbel-Max Trick也是使用了重参数技巧把采样过程分成了确定性的部分和随机性的部分,我们会计算所有类别的log分布概率(确定性的部分),类似于上面例子中的均值,然后加上一些噪音(随机性的部分),上面的例子中,噪音是标准高斯分布,而这里噪音是标准gumbel分布。在我们把采样过程的确定性部分和随机性部分结合起来之后,我们在此基础上再用一个argmax来找到具有最大概率的类别。自此可见,Gumbel-Max Trick由使用gumbel分布的Re-parameterization Trick和argmax组成而成,正如它的名字一样。
用公式表示的话就是:
下面我们用一个例子和代码来验证一下这个极值分布的规律。假设你每天都会喝很多次水(比如100次),每次喝水的量服从正态分布N(μ,σ2)(其实也有点不合理,毕竟喝水的多少不能取为负值,不过无伤大雅能理解就好,假设均值为5),那么每天100次喝水里总会有一个最大值,这个最大值服从的分布就是Gumbel分布。
上面的例子中极值是采样自高斯分布,且是连续分布,那如果极值是采样自一个离散的类别分布呢,下面我们再用代码来验证一下。
如下代码定义了一个7类别的多项分布,每个类别的概率如下图
可以看到,两个分布几乎一模一样。
Gumbel Softmax
最后总结一下Gumbel-Softmax Trick的步骤:
其实我觉得只是把argmax替换成softmax还不够,应该是替换成soft argmax,这一点有待以后如果工作中遇到了再验证。也欢迎有实践经验的朋友告知。
pytorch相关函数说明
pytorch 提供的torch.nn.functional.gumbel_softmax
api:https://pytorch.org/docs/stable/generated/torch.nn.functional.gumbel_softmax.html#torch.nn.functional.gumbel_softmax
视频讲解:《Gumbel Softmax补充说明》
实现的源代码:https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
我这里对实现的源代码做一些说明:
可以看到两个分布十分近似,所以pytorch源代码里使用指数分布采样是没问题的。
本文部分内容参考或摘抄自:
《gumber分布的维基百科》
《Gumbel-Softmax 完全解析》
《Gumbel-Softmax Trick和Gumbel分布 》
《The Gumbel-Softmax Distribution》
《Gumbel softmax trick (快速理解附代码)》
《漫谈重参数:从正态分布到Gumbel Softmax》