Gumbel Softmax

Gumbel-Softmax是一种在深度学习中使离散选择过程可微分的技术,它结合了GumbelMaxTrick和softmax函数。通过这种方式,模型能在反向传播时更新参数,应用包括分类、生成任务、强化学习等。文章介绍了Gumbel分布的采样方法、重参数技巧以及Gumbel-Softmax的实现和作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

全面详解gumbel softmax

Gumbel-Softmax有什么用 ?

据我所知,gumbel softmax允许模型中有从离散的分布(比如类别分布categorical distribution)中采样的这个过程变得可微,从而允许反向传播时可以用梯度更新模型参数,所以这让gumbel softmax在深度学习的很多领域都有应用,比如分类分割任务、采样生成类任务AIGC、强化学习、语音识别、NAS等等。如果你是主动搜索到这篇文章的,那你对gumbel softamx的应用应该有自己的理解,如果跟我一样,暂时没用到的,也可以先学起来,说不定以后的算法能用上。

2. 这个函数可导

基于前人们的知识成果积累,论文《Categorical Reparameterization with Gumbel-Softmax》的作者还真找到了解决方法,第一个问题的方法是使用Gumbel Max Trick,第二个问题的方法是把Gumbel Max Trick里的argmax换成softmax,综合起来就是Gumbel Softmax。

前置知识

累计分布函数

在介绍gumbel之前, 我们先看一下离散概率分布采样在计算机编程中是如何实现的。它的采样方法可以表示为:

从上图我们可以感受到,采样值在x=3附近比较多,密度比较高,所以相应的它的概率密度函数(PDF,Probability Density Function)在x=3处是最大的,如下图所示: 

不同参数的gumbel分布的PDF函数曲线 whaosoft aiot http://143ai.com 

写成代码的话,就是

import torch

# gumbel分布的CDF函数的反函数
def inverse_gumbel_cdf(u, loc, beta):
    return loc - scale * torch.log(-torch.log(u))

def gumbel_distribution_sampling(n, loc=0, scale=1):
    u = torch.rand(n)  #使用torch.rand生成均匀分布
    g = inverse_gumbel_cdf(u, loc, scale)
    return g

n = 10  # 采样个数
loc = 0 # gumbel分布的位置系数,类似于高斯分布的均值
scale = 1 # gumbel分布尺度系数,类似于高斯分布的标准差

samples = gumbel_distribution_sampling(n, loc, scale)

重参数技巧(Re-parameterization Trick)

gumbel max trick里用到了重参数的思想,所以先介绍一下重参数技巧。

最原始的自编码器(AE&#

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值