pytorch实现straight-through estimator(STE)

现在深度学习中一般我们学习的参数都是连续的,因为这样在反向传播的时候才可以对梯度进行更新。但是有的时候我们也会遇到参数是离>散的情况,这样就没有办法进行反向传播了,比如二值神经网络。本文中讲解了如何用pytorch对二值化的参数进行梯度更新的straight-through estimator算法。
Question:
STE核心的思想就是我们的参数初始化的时候就是float这样的连续值,当我们forward的时候就将原来的连续的参数映射到{-1, 1}带入到网络进行计算,这样就可以计算网络的输出。然后backward的时候直接对原来float的参数进行更新,而不是对二值化的参数更新。这样可以完成对整个网络的更新了。
首先我们对上面问题进行一下数学的讲解。
在这里插入图片描述
Example:
首先我们验证一下使用torch.sign会是参数的梯度基本上都是0:

>>> input = torch.randn(4, requires_grad = True)
>>> output = torch.sign(input)
>>> loss = output.mean()
>>> loss.backward()
>>> input
tensor([-0.8673, -0.0299, -1.1434, -0.6172], requires_grad=True)
>>> input.grad
tensor([0., 0., 0., 0.])

我们需要重写sign这个函数,就好像写一个激活函数一样。

import torch

class LBSign(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clamp_(-1, 1)
import torch
from LBSign import LBSign

if __name__ == '__main__':

    sign = LBSign.apply
    params = torch.randn(4, requires_grad = True)                                                                           
    output = sign(params)
    loss = output.mean()
    loss.backward()

测试梯度:

>>> params
tensor([-0.9143,  0.8993, -1.1235, -0.7928], requires_grad=True)
>>> params.grad
tensor([0.2500, 0.2500, 0.2500, 0.2500])

在这里插入图片描述

文章转载:https://segmentfault.com/a/1190000020993594?utm_source=tag-newest仅供参考学习,如有侵权则请联系博主。

参考文献:

  • https://segmentfault.com/a/1190000020993594?utm_source=tag-newest
  • 12
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值