怎样克服神经网络训练中argmax的不可导性?

1. strainght through Gumbel (estimator)

令: a r g m a x ( v ) = s o f t m a x ( v ) + c ; c = a r g m a x ( v ) − s o f t m a x ( v ) , 且 为 常 数 argmax(v)=softmax(v) + c ; c=argmax(v) -softmax(v),且为常数 argmax(v)=softmax(v)+c;c=argmax(v)softmax(v),
在这里插入图片描述
在这里插入图片描述

2. stop gradient operation

在这里插入图片描述
方法:正向传播就和往常一样,反向传播时,将梯度从不可导那个点copy到 不可导点的前面的最近一个可导点。
q u a n t i z e = i n p u t + ( q u a n t i z e − i n p u t ) . d e t a c h ( ) quantize = input + (quantize - input).detach() quantize=input+(quantizeinput).detach()
在这里插入图片描述

3. 可以对argmax/argmin 这种不可导的操作直接忽视,也就是锁定

就是抛弃不可传导的位置

class ArgMax(torch.autograd.Function):
	@staticmethod
	def forward(ctx, input):
        idx = torch.argmax(input, 1)
        output = torch.zeros_like(input)
        output.scatter_(1, idx, 1) # 此处直接用1来替换argmax的位置,抛弃了此处的梯度
        return output
	
	@staticmethod
	def backward(ctx, grad_output):
        return grad_output
  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值