【机器学习】gumbel softmax的介绍

一、介绍

Gumbel-Softmax 是一种技术,用于在离散选择中引入可微分的近似。这对于需要在神经网络中进行离散采样(如分类任务或生成离散数据)而不破坏梯度计算非常有用。Gumbel-Softmax 可以看作是对经典的 Softmax 函数的一种扩展,结合了 Gumbel 噪声,用于逼近离散的 one-hot 向量,同时保持梯度的可计算性。

在许多机器学习任务中,需要从一个离散的分布中采样。例如,在强化学习或生成模型中,可能需要从一组离散的动作或词汇中进行选择。然而,直接从离散分布中采样是不连续的,这意味着无法通过反向传播来更新模型参数。

举个例子:

在分类任务中,神经网络的最后一层通常是一个全连接层,接着是一个Softmax函数,将网络输出转化为概率分布。例如,对于一个有3个类别的分类任务,网络的输出可能是:

logits=[1.2,0.9,2.5]

通过Softmax函数将其转化为概率:

probs=Softmax(logits)=[0.25,0.20,0.55]

然后,通常选择概率最大的类别作为预测结果:

prediction=arg⁡max⁡(probs)=2

然而,离散操作(如arg⁡max)是不可微的,这意味着无法通过反向传播来更新参数。

假设有一个简单的损失函数,它依赖于网络的输出类别。如果直接使用 arg⁡max⁡来选择类别,梯度无法通过这个操作传递回网络的参数:

import torch
import torch.nn.functional as F

log
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值