Gumbel-Softmax的logits输入可以是模型的输出

如下是Gumbel-Softmax的pytorch代码实现:

def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
    # _gumbels = (-torch.empty_like(
    #     logits,
    #     memory_format=torch.legacy_contiguous_format).exponential_().log()
    #             )  # ~Gumbel(0,1)
    # more stable https://github.com/pytorch/pytorch/issues/41663

    # example logits: [batch_size, n_class] unnormalized log-probs
    gumbel_dist = torch.distributions.gumbel.Gumbel(
        torch.tensor(0., device=logits.device, dtype=logits.dtype),
        torch.tensor(1., device=logits.device, dtype=logits.dtype))
    gumbels = gumbel_dist.sample(logits.shape)

    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret

在Gumbel-Softmax的使用中,可以使用未归一化的网络输出(即未经过 Softmax 处理也未经过 Log 处理)作为logits,这是因为Gumbel-Softmax的采样过程本质上依赖于logits的相对大小,而不绝对要求logits是概率的log值(为什么这样使用?个人认为是为了简化计算、数值稳定或提供更好的梯度性质)。以下是Gumbel-Softmax的公式,logits指的是 log ⁡ ( π ) \log(\pi) log(π) π \pi π指的是概率, g g g指的是Gumbel分布:
y i = exp ⁡ ( ( log ⁡ ( π i ) + g i ) / τ ) ∑ j = 1 k exp ⁡ ( ( log ⁡ ( π j ) + g j ) / τ ) y_i = \frac{\exp((\log(\pi_i) + g_i) / \tau)}{\sum_{j=1}^k \exp((\log(\pi_j) + g_j) / \tau)} yi=j=1kexp((log(πj)+gj)/τ)exp((log(πi)+gi)/τ)
也可以参考Gumbel-Softmax官方代码的使用示例

  • 12
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值