一、介绍
Gumbel-Softmax 是一种技术,用于在离散选择中引入可微分的近似。这对于需要在神经网络中进行离散采样(如分类任务或生成离散数据)而不破坏梯度计算非常有用。Gumbel-Softmax 可以看作是对经典的 Softmax 函数的一种扩展,结合了 Gumbel 噪声,用于逼近离散的 one-hot 向量,同时保持梯度的可计算性。
在许多机器学习任务中,需要从一个离散的分布中采样。例如,在强化学习或生成模型中,可能需要从一组离散的动作或词汇中进行选择。然而,直接从离散分布中采样是不连续的,这意味着无法通过反向传播来更新模型参数。
举个例子:
在分类任务中,神经网络的最后一层通常是一个全连接层,接着是一个Softmax函数,将网络输出转化为概率分布。例如,对于一个有3个类别的分类任务,网络的输出可能是:
logits=[1.2,0.9,2.5]
通过Softmax函数将其转化为概率:
probs=Softmax(logits)=[0.25,0.20,0.55]
然后,通常选择概率最大的类别作为预测结果:
prediction=argmax(probs)=2
然而,离散操作(如argmax)是不可微的,这意味着无法通过反向传播来更新参数。
假设有一个简单的损失函数,它依赖于网络的输出类别。如果直接使用 argmax来选择类别,梯度无法通过这个操作传递回网络的参数:
import torch
import torch.nn.functional as F
log