【机器学习】gumbel softmax的介绍

一、介绍

Gumbel-Softmax 是一种技术,用于在离散选择中引入可微分的近似。这对于需要在神经网络中进行离散采样(如分类任务或生成离散数据)而不破坏梯度计算非常有用。Gumbel-Softmax 可以看作是对经典的 Softmax 函数的一种扩展,结合了 Gumbel 噪声,用于逼近离散的 one-hot 向量,同时保持梯度的可计算性。

在许多机器学习任务中,需要从一个离散的分布中采样。例如,在强化学习或生成模型中,可能需要从一组离散的动作或词汇中进行选择。然而,直接从离散分布中采样是不连续的,这意味着无法通过反向传播来更新模型参数。

举个例子:

在分类任务中,神经网络的最后一层通常是一个全连接层,接着是一个Softmax函数,将网络输出转化为概率分布。例如,对于一个有3个类别的分类任务,网络的输出可能是:

logits=[1.2,0.9,2.5]

通过Softmax函数将其转化为概率:

probs=Softmax(logits)=[0.25,0.20,0.55]

然后,通常选择概率最大的类别作为预测结果:

prediction=arg⁡max⁡(probs)=2

然而,离散操作(如arg⁡max)是不可微的,这意味着无法通过反向传播来更新参数。

假设有一个简单的损失函数,它依赖于网络的输出类别。如果直接使用 arg⁡max⁡来选择类别,梯度无法通过这个操作传递回网络的参数:

import torch
import torch.nn.functional as F

logits = torch.tensor([1.2, 0.9, 2.5])
probs = F.softmax(logits, dim=0)
prediction = torch.argmax(probs)

# 假设有一个简单的损失函数
loss = torch.tensor(0.0)
if prediction == 2:
    loss = torch.tensor(1.0)

# 反向传播
loss.backward()  # 这里将会失败,因为 torch.argmax 是不可微的

二、Gumbel-Softmax 的作用

Gumbel-Softmax 提供了一种可微的近似,使可以在保持反向传播可行的情况下进行离散选择。通过Gumbel-Softmax,可以从离散分布中进行采样,并且保持梯度的传递。

三、Gumbel-Softmax 的工作原理

  1. 添加Gumbel噪声:首先在logits上添加Gumbel噪声,使其具有随机性。

  2. 应用Softmax:然后应用Softmax函数,使得输出接近one-hot向量,但仍然是连续的,从而可以进行梯度计算。

 
import torch
import torch.nn.functional as F

def sample_gumbel(shape, eps=1e-20):
    U = torch.rand(shape)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    gumbel_noise = sample_gumbel(logits.size())
    y = logits + gumbel_noise
    return F.softmax(y / temperature, dim=-1)

def gumbel_softmax(logits, temperature, hard=False):
    y = gumbel_softmax_sample(logits, temperature)
    
    if hard:
        shape = y.size()
        _, max_idx = y.max(dim=-1, keepdim=True)
        y_hard = torch.zeros_like(y).scatter_(-1, max_idx, 1.0)
        y = (y_hard - y).detach() + y
    
    return y

logits = torch.tensor([1.2, 0.9, 2.5])
temperature = 0.5

samples = gumbel_softmax(logits, temperature, hard=True)
print(samples)

# 我们有一个简单的损失函数
target = torch.tensor([0, 0, 1])  # 假设真实类别是第三类
loss = F.cross_entropy(samples.unsqueeze(0), target.unsqueeze(0).argmax(dim=-1))

# 反向传播
loss.backward()  # 这里是可行的,因为Gumbel-Softmax是可微的

四、公式推导

假设有一个logits向量 z(通常是神经网络的输出),其第 i 个元素为 zi。

添加Gumbel噪声: 对每个logit zi​ 添加一个从Gumbel分布中采样的噪声 gi​。Gumbel噪声 gi 的公式为:

g_i=-\log(-\log(U_i))

其中 U 是从均匀分布 U(0,1) 中采样的随机变量

应用Softmax: 对添加了Gumbel噪声的logits进行缩放,并应用Softmax函数:

y_i=\frac{\exp\left(\frac{z_i+g_i}{\tau}\right)}{\sum_j\exp\left(\frac{z_j+g_j}{\tau}\right)}

其中,τ 是温度参数,控制输出的平滑程度。当τ 趋近于0时,输出接近于one-hot向量;当τ 趋近于无穷大时,输出接近均匀分布。

对于 Gumbel-Softmax,梯度的计算与 Softmax 类似,只不过在 logits 中添加了 Gumbel 噪声。具体来说,对于 Gumbel-Softmax 的输出 yi 和输入 logits zi​,其梯度计算如下:

\frac{\partial y_i}{\partial z_k}=\frac{\partial}{\partial z_k}\left(\frac{\exp\left(\frac{z_i-\log(-\log(U_i))}{\tau}\right)}{\sum_j\exp\left(\frac{z_j-\log(-\log(U_j))}{\tau}\right)}\right)

本质上还是应用 Softmax 的梯度公式,但在计算过程中包含了 Gumbel 噪声。

  • 19
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值