全面详解gumbel softmax
Gumbel-Softmax有什么用 ?
据我所知,gumbel softmax允许模型中有从离散的分布(比如类别分布categorical distribution)中采样的这个过程变得可微,从而允许反向传播时可以用梯度更新模型参数,所以这让gumbel softmax在深度学习的很多领域都有应用,比如分类分割任务、采样生成类任务AIGC、强化学习、语音识别、NAS等等。如果你是主动搜索到这篇文章的,那你对gumbel softamx的应用应该有自己的理解,如果跟我一样,暂时没用到的,也可以先学起来,说不定以后的算法能用上。
2. 这个函数可导
基于前人们的知识成果积累,论文《Categorical Reparameterization with Gumbel-Softmax》的作者还真找到了解决方法,第一个问题的方法是使用Gumbel Max Trick,第二个问题的方法是把Gumbel Max Trick里的argmax换成softmax,综合起来就是Gumbel Softmax。
前置知识
累计分布函数
在介绍gumbel之前, 我们先看一下离散概率分布采样在计算机编程中是如何实现的。它的采样方法可以表示为:
从上图我们可以感受到,采样值在x=3附近比较多,密度比较高,所以相应的它的概率密度函数(PDF,Probability Density Function)在x=3处是最大的,如下图所示:
不同参数的gumbel分布的PDF函数曲线 whaosoft aiot http://143ai.com
写成代码的话,就是
import torch
# gumbel分布的CDF函数的反函数
def inverse_gumbel_cdf(u, loc, beta):
return loc - scale * torch.log(-torch.log(u))
def gumbel_distribution_sampling(n, loc=0, scale=1):
u = torch.rand(n) #使用torch.rand生成均匀分布
g = inverse_gumbel_cdf(u, loc, scale)
return g
n = 10 # 采样个数
loc = 0 # gumbel分布的位置系数,类似于高斯分布的均值
scale = 1 # gumbel分布尺度系数,类似于高斯分布的标准差
samples = gumbel_distribution_sampling(n, loc, scale)
重参数技巧(Re-parameterization Trick)
gumbel max trick里用到了重参数的思想,所以先介绍一下重参数技巧。
最原始的自编码器(AE&#