Pytorch 可微分round函数

round函数在定义域中的导数,处处为0或者无穷,梯度无法反向传播。本文将使用autograd.function类自定义可微分的round函数,使得round前后的tensor,具有相同的梯度。

from torch.autograd import Function


class BypassRound(Function):
  @staticmethod
  def forward(ctx, inputs):
    return torch.round(inputs)

  @staticmethod
  def backward(ctx, grad_output):
    # 这里的grad_output是round之后的tensor的梯度,直接将它作为round之前tensor的梯度
    return grad_output


# Function.apply的别名
bypass_round = BypassRound.apply

# demo
z3_rounded = bypass_round(z3)

 具体原理和细节参考以下博客:

定义torch.autograd.Function的子类,自己定义某些操作,且定义反向求导函数_tsq292978891的博客-CSDN博客_saved_tensors

2022.4.7更新:更简单的方法如下

def ste_round(x):
    return torch.round(x) - x.detach() + x

torch.round(x)导数处处为0,x.detach()在计算图中,x的导数为1

因此:ste_round(x)的梯度 == x的梯度

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值