点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
作者丨wwdok@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/633431594
编辑丨极市平台
极市导读
全面详解gumbel softmax。
前言
我在学习《CLIP 改进工作串讲(上)【论文精读·42】》的过程中,听到朱老师讲到了GroupViT中用到了gumbel softmax(相关源代码),于是我带着好奇心试图想去了解gumbel softmax是什么,最后我把我的理解写成这篇文章,但是目前我在工作中还没用到gumbel softmax,所以如果有说得不对的地方,欢迎指正。
Gumbel-Softmax有什么用 ?
据我所知,gumbel softmax允许模型中有从离散的分布(比如类别分布categorical distribution)中采样的这个过程变得可微,从而允许反向传播时可以用梯度更新模型参数,所以这让gumbel softmax在深度学习的很多领域都有应用,比如分类分割任务、采样生成类任务AIGC、强化学习、语音识别、NAS等等。如果你是主动搜索到这篇文章的,那你对gumbel softamx的应用应该有自己的理解,如果跟我一样,暂时没用到的,也可以先学起来,说不定以后的算法能用上。
我们还是通过一个简单的例子来切入。假设我们有一个神经网络模型, 模型中间某一层的输出是 个类别的概率, 它的概率分布可以表示为:
, 且 。紧接着我们使用argmax挑出概率最大的那个类别索引 , 然后用它继续前向传播, 这么做在前向传播时没有问题, 但在反向传播时, 梯度回传到这里, 就会卡住, 因为我们生成 的公式argmax无法求导, 梯度无法再反向传播下去。于是有人就想, 能不能构造一个生成 的 公式, 用它来取代argmax, 这个公式需要具备以下特点:
1. 以 为参数, 且这个公式输出某个采样值的概率 跟原来的概率 一样(或者说十分近似)
2. 这个函数可导
基于前人们的知识成果积累,论文《Categorical Reparameterization with Gumbel-Softmax》的作者还真找到了解决方法,第一个问题的方法是使用Gumbel Max Trick,第二个问题的方法是把Gumbel Max Trick里的argmax换成softmax,综合起来就是Gumbel Softmax。
前置知识
累计分布函数
在介绍gumbel之前, 我们先看一下离散概率分布采样在计算机编程中是如何实现的。它的采样方法可以表示为:
, 其中 是类别的下标, 随机变量 服从均匀分布 。这个采样方法实际上是很巧妙的, 首先它从均匀分布 中采样出一个随机值 , 然后它将概率分布从前往后不断累加起来, 当类加到 时超过了随机值 , 那么这一次随机采样过程, 就被采样为第 类。 在做的事情是累计概率, 这和累计分布函数 (CDF, Cumulative distribution funtion) 的作用差不多, 关于这一点的更多讲解可以看视频《Gumbel-softmax 中文解读 Categorical Reparameterization with Gumbel-Softmax》 02:12~04:25。也就是说, 利用一个均匀分布和某个分布的CDF, 就可以实现某个分布的采样。上面的例子是离散的形式, 下面再举个连续的例子, 假设我们要从gumbel分布中采样, gumbel分布的CDF公式可以见《gumber分布的维基百科》, 首先在[0,1]之间均匀采样, 代表均匀分布 采样出来的值, 假设采样9个值, 那就是 , 把这 9 个值作为y值代入 gumbel 的CDF函数, 求出x, 这个x就是采样得到的值, 这里使用下图的青色线演示:

不同参数的gumbel分布的CDF函数曲线
从上图我们可以感受到,采样值在x=3附近比较多,密度比较高,所以相应的它的概率密度函数(PDF,Probability Density Function)在x=3处是最大的,如下图所示:

不同参数的gumbel分布的PDF函数曲线
写成代码的话,就是
import torch
# gumbel分布的CDF函数的反函数
def inverse_gumbel_cdf(u, loc, beta):
return loc - scale * torch.log(-torch.log(u))
def gumbel_distribution_sampling(n, loc=0, scale=1):
u = torch.rand(n) #使用torch.rand生成均匀分布
g = inverse_gumbel_cdf(u, loc, scale)
return g
n = 10 # 采样个数
loc = 0 # gumbel分布的位置系数,类似于高斯分布的均值
scale = 1 # gumbel分布尺度系数,类似于高斯分布的标准差
samples = gumbel_distribution_sampling(n, loc, scale)
gumbel max trick公式里就用到了这个采样思想,即先用均匀分布采样出一个随机值,然后把这个值带入到gumbel分布的CDF函数的逆函数(inverse function,或者称为反函数)得到采样值。另外值得一说的是,gumbel max trick里使用的gumbel分布是标准gumbel分布,即 (通用的gumbel分布CDF和PDF公式见维基百科),标准gumbel分布的CDF是 ,那它的逆函数就是 。
重参数技巧(Re-parameterization Trick)
gumbel max trick里用到了重参数的思想,所以先介绍一下重参数技巧。
最原始的自编码器(AE,Auto Encoder,自编码器就是输入一张图片,编码成一个隐向量,再把这个隐向量重建回原图的样子)长这样:

左右两边是端到端的输入输出网络,中间的绿色是提取的特征向量,这是一种直接从图片提取特征并将特征直接重建回去的方式,很符合直觉。
而VAE(Variational Auto Encoder)长这样:

VAE的想法是不直接用编码器去提取特征向量(也就是隐向量),而是提取这张图像的分布特征,比如说均值和标准差,也就是把绿色的特征向量替换为分布的参数向量。然后需要解码图像的时候,就用编码器输出的分布参数采样得到特征向量样本,用这个样本去重建图像。
以上就是重参数技巧在图像生成领域的一个案例,可以表示为下图所示:

那为什么要这么做呢? 这个跟梯度反向传播有关。首先看上图左侧, 假设图中的 和 表示VAE 中的均值和标准差向量, 它们是确定性的节点, 而输出的样本 是带有随机性的节点, 在梯度反向传播时, 就会卡在z节点(这里梯度回传卡在这里不是因为不可导, 而是因为随机性)。重参数就是把带有随机性的z变成确定性的节点, 同时把随机性转嫁给另一个输入节点 日 。例如, 这里用正态分布采样, 原本从均值为 和标准差为 的正态分布 中采样得到 , 将其转化成从标准正态分布 中采样得到 , 再通过重参数技巧计算得到 。这样一来, 采样的过程移出了梯度反向传播的路径, 计算图里的参数 (均值x和标准差 ) 就可以用梯度更新了, 而新加的 的输入分支不做更新, 只当成一个没有权重变化的输入。
用博客《The Gumbel-Softmax Distribution》的说法再复述一遍, 重参数就是把原来完全随机的 节点分成了确定的节点和随机的节点两部分:
where , 如下图所示(意思跟上图差不多):

Gumbel-Max Trick
Gumbel-Max Trick也是使用了重参数技巧把采样过程分成了确定性的部分和随机性的部分,我们会计算所有类别的log分布概率(确定性的部分),类似于上面例子中的均值,然后加上一些噪音(随机性的部分),上面的例子中,噪音是标准高斯分布,而这里噪音是标准gumbel分布。在我们把采样过程的确定性部分和随机性部分结合起来之后,我们在此基础上再用一个argmax来找到具有最大概率的类别。自此可见,Gumbel-Max Trick由使用gumbel分布的Re-parameterization Trick和argmax组成而成,正如它的名字一样。
用公式表示的话就是:
其中 , 这一项就是从 gumbel 分布采样得到的噪声, 目的是使得 的返回结果不固定, 它是标准gumbel分布的CDF的逆函数, 上面已经用它举过例子了。用一个例子说明一下Gumbel-Max Trick做的事情:[0.1, 0.7, 0.2] -> [log(0.1) + gumbel_noise, gumbel_noise, gumbel_noise]
那为什么随机部分要用gumbel分布而不是常见的高斯分布呢? 这是因为gumbel分布是专门用来建模从其他分布 (比如高斯分布) 采样出来的极值形成的分布, 而我们这里“使用argmax挑出概率最大的那个类别索引 "就属于取极值的操作, 所以它属于极值分布。在此之前, 你可能想都不敢想, 极值形成的分布竞然是有规律的, 可的确就是有这么神奇的存在, 这就是数学的魅力所在, 但是要加个条件, 就是极值是采样自某一个指数族的概率分布, 比如高斯分布。
下面我们用一个例子和代码来验证一下这个极值分布的规律。假设你每天都会喝很多次水(比如100次),每次喝水的量服从正态分布N(μ,σ2)(其实也有点不合理,毕竟喝水的多少不能取为负值,不过无伤大雅能理解就好,假设均值为5),那么每天100次喝水里总会有一个最大值,这个最大值服从的分布就是Gumbel分布。
from scipy.optimize import curve_fit
import numpy as np
import matplotlib.pyplot as plt
mean_hunger = 5
samples_per_day = 100
n_days = 10000
samples = np.random.normal(loc=mean_hunger, scale=1.0, size=(n_days, samples_per_day))
daily_maxes = np.max(samples, axis=1)
# gumbel的通用PDF公式见维基百科
def gumbel_pdf(prob,loc,scale):
z = (prob-loc)/scale
return np.exp(-z-np.exp(-z))/scale
def plot_maxes(daily_maxes):
probs,bins,_ = plt.hist(daily_maxes,density=True,bins=100)
print(f"==>> probs: {probs}") # 每个bin的概率
print(f"==>> bins: {bins}") # 即横坐标的tick值
print(f"==>> _: {_}")
print(f"==>> probs.shape: {probs.shape}") # (100,)
print(f"==>> bins.shape: {bins.shape}") # (101,)
plt.xlabel('Volume')
plt.ylabel('Probability of Volume being daily maximum')
# 基于直方图,下面拟合出它的曲线。
(fitted_loc, fitted_scale), _ = curve_fit(gumbel_pdf, bins[:-1],probs)
print(f"==>> fitted_loc: {fitted_loc}")
print(f"==>> fitted_scale: {fitted_scale}")
#curve_fit用于曲线拟合,doc:https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html
#比如我们要拟合函数y=ax+b里的参数a和b,a和b确定了,这个函数就确定了,为了拟合这个函数,我们需要给curve_fit()提供待拟合函数的输入和输出样本
#所以curve_fit()的三个入参是:1.待拟合的函数(要求该函数的第一个入参是输入,后面的入参是要拟合的函数的参数)、2.样本输入、3.样本输出
#返回的是拟合好的参数,打包在元组里
# 其他教程:https://blog.csdn.net/guduruyu/article/details/70313176
plt.plot(bins, gumbel_pdf(bins, fitted_loc, fitted_scale))
plt.figure()
plot_maxes(daily_maxes)

上面的例子中极值是采样自高斯分布,且是连续分布,那如果极值是采样自一个离散的类别分布呢,下面我们再用代码来验证一下。
如下代码定义了一个7类别的多项分布,每个类别的概率如下图
n_cats = 7
cats = np.arange(n_cats)
probs = np.random.randint(low=1, high=20, size=n_cats)
probs = probs / sum(probs)
logits = np.log(probs)
def plot_probs():
plt.bar(cats, probs)
plt.xlabel("Category")
plt.ylabel("Probability")
plt.figure()
plot_probs()

接下来我们将用代码演示为什么 里加的噪音 得是gumbel分布, 而不能是高斯分布或均匀分布
def sample_gumbel(logits):
noise = np.random.gumbel(size=len(logits))
sample = np.argmax(logits+noise)
return sample
gumbel_samples = [sample_gumbel(logits) for _ in range(n_samples)]
def sample_uniform(logits):
noise = np.random.uniform(size=len(logits))
sample = np.argmax(logits+noise)
return sample
uniform_samples = [sample_uniform(logits) for _ in range(n_samples)]
def sample_normal(logits):
noise = np.random.normal(size=len(logits))
sample = np.argmax(logits+noise)
return sample
normal_samples = [sample_normal(logits) for _ in range(n_samples)]
plt.figure(figsize=(10,4))
plt.subplot(1,4,1)
plot_probs()
plt.subplot(1,4,2)
gumbel_estd_probs = plot_estimated_probs(gumbel_samples,'Gumbel ')
plt.subplot(1,4,3)
normal_estd_probs = plot_estimated_probs(normal_samples,'Normal ')
plt.subplot(1,4,4)
uniform_estd_probs = plot_estimated_probs(uniform_samples,'Uniform ')
plt.tight_layout()
print('Original probabilities:\t\t',end='')
print_probs(probs)
print('Gumbel Estimated probabilities:\t',end='')
print_probs(gumbel_estd_probs)
print('Normal Estimated probabilities:\t',end='')
print_probs(normal_estd_probs)
print('Uniform Estimated probabilities:',end='')
print_probs(uniform_estd_probs)

可以看到, 只有加的噪声是gumbel分布, 最后的概率分布才跟原来的分布差不多, 加高斯分布和均匀分布的噪声的概率分布跟原来的概率分布明显差别很大。由此可见, 采样得到的概率分布跟原来的概率分布 几乎一样。 代码中使用 np. random. gumbel ( ) 来从gumbel分布中采样, 而这可以通过 等效实现, 这里为了简单起见, 使用标准gumbel分布 ( ), 我们再用代码验证一下, :
n_samples = 100000
numpy_gumbel = np.random.gumbel(loc=0, scale=1.0, size=n_samples)
manual_gumbel = -np.log(-np.log(np.random.uniform(size=n_samples)))
plt.figure()
plt.subplot(1, 2, 1)
plt.hist(numpy_gumbel, bins=50)
plt.ylabel("Probability")
plt.xlabel("numpy Gumbel")
plt.subplot(1, 2, 2)
plt.hist(manual_gumbel, bins=50)
plt.xlabel("Gumbel from uniform noise")

可以看到,两个分布几乎一模一样。
话说回来, 那为什么通过 采样出来的概率分布跟原来的概率分布 一样呢? 证明过程见方法一(https://kexue.fm/archives/6705)或方法二(https://www.cnblogs.com/initial-h/p/9468974.html)。这里补充说明一下方法一的证明过程中的红框部分。

这句话的意思是因为 是均匀分布 , 也就是说 可能是 0 和 1 之间的任意值, 所以我们要考虑所有情况的话, 就要对 在0和1区间做积分, 借助Geogebra工具, 验证了这个积分确实恒等于 (或者说近似) p1, 如下动图所示:
Gumbel Softmax
Gumbel-Max Trick中含有不可导的部分 argmax,这个问题可以用可导的 softmax 函数替换它来解决, 最终新的概率分布 的公式为:
其中, 是一个温度系数, 跟知识蒸馏里那个一样, 越小 , 整个 softmax 越逼近 argmax, 越大, 越接近于均匀分布。论文 《Categorical Reparameterization with GumbelSoftmax》中就有一张图描述了这个特性:

最后总结一下Gumbel-Softmax Trick的步骤:
对于网络输出的一个 维向量 ,生成 个服从均匀分布 的独立样本
通过 计算得到
对应相加得到新的值向量
通过 softmax函数计算各个类别的概率大小, 其中 是温度参数:
其实我觉得只是把argmax替换成softmax还不够,应该是替换成soft argmax,这一点有待以后如果工作中遇到了再验证。也欢迎有实践经验的朋友告知。
pytorch相关函数说明
pytorch 提供的torch.nn.functional.gumbel_softmax
api:https://pytorch.org/docs/stable/generated/torch.nn.functional.gumbel_softmax.html#torch.nn.functional.gumbel_softmax
视频讲解:《Gumbel Softmax补充说明》
实现的源代码:https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
我这里对实现的源代码做一些说明:
gumbels = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
) # ~Gumbel(0,1)
gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
y_soft = gumbels.softmax(dim)
if hard:
# Straight through.
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
else:
# Reparametrization trick.
ret = y_soft
return ret
说明:代码中的logits已经经过了log()处理, 相当于公式里的 , 还有在计算gumbels时, 源 码里使用了 exponential_(), 它的文档里说了, 这代表的是从指数分布 中采样, 这跟我们前面公式里的 不太一样, 它没有从均匀分布 里采样, 而是从指数分布 里采样, 但是其实两者殊途同归, 因为 的逆函数是 (其实代码里的log一般是 ), 所以其实 代表的就是从指数分布 里采样, 我们同样可以用代码来验证一下:
import numpy as np
import matplotlib.pyplot as plt
n_samples = 100000
numpy_exponential = np.random.exponential(size=n_samples)
manual_exponential = -np.log(np.random.uniform(size=n_samples))
plt.figure()
plt.subplot(1, 2, 1)
plt.hist(numpy_exponential, bins=50)
plt.ylabel("Probability")
plt.xlabel("numpy exponential")
plt.subplot(1, 2, 2)
plt.hist(manual_exponential, bins=50)
plt.xlabel("Exponential from uniform noise")

可以看到两个分布十分近似,所以pytorch源代码里使用指数分布采样是没问题的。
本文部分内容参考或摘抄自:
《gumber分布的维基百科》
《Gumbel-Softmax 完全解析》
《Gumbel-Softmax Trick和Gumbel分布 》
《The Gumbel-Softmax Distribution》
《Gumbel softmax trick (快速理解附代码)》
《漫谈重参数:从正态分布到Gumbel Softmax》
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~