为什么 torch.nn.LogSigmoid(x) 可以避免 torch.log(torch.sigmoid(x)) 可能存在的 inf

当 x 非常接近 0 的时候,sigmoid(x) 会接近 0,这时候给计算结果取对数可能会导致数值溢出,最终输出 inf,但 torch.nn.LogSigmoid(x) 就不会导致溢出。

看了一下 torch 的实现:

void launch_log_sigmoid_forward_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "log_sigmoid_forward_cuda", [&] {
        using opmath_t = at::opmath_type<scalar_t>;
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t in_) -> scalar_t {
          const opmath_t in = in_;
          const auto min = std::min(opmath_t(0), in);
          const auto z = std::exp(-std::abs(in));
          return min - std::log1p(z);
        });
      });
}

这里使用了一个 std::log1p 函数,查阅该函数的文档:

  • This function is more precise than the expression std::log(1 + num) if num is close to zero.
  • If the argument is ±0, it is returned unmodified.

相当于这个 cpp 的库函数保证了在数值非常小的时候的精度。

当然,还有一种很有趣的写法,来源是 THNN,利用了 LogSumExp 的 trick:

void THNN_(LogSigmoid_updateOutput)(
          THNNState *state,
          THTensor *input,
          THTensor *output,
          THTensor *buffer)
{
  THTensor_(resizeAs)(output, input);
  THTensor_(resizeAs)(buffer, input);
  //Use the LogSumExp trick to make this stable against overflow
  TH_TENSOR_APPLY3(real, output, real, input, real, buffer,
    real max_elem = fmax(0, -*input_data);
    real z = exp(-max_elem) + exp(-*input_data - max_elem);
    *buffer_data = z;
    *output_data = -(max_elem + log(z));
  );
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值