文章目录
1. strainght through Gumbel (estimator)
令:
a
r
g
m
a
x
(
v
)
=
s
o
f
t
m
a
x
(
v
)
+
c
;
c
=
a
r
g
m
a
x
(
v
)
−
s
o
f
t
m
a
x
(
v
)
,
且
为
常
数
argmax(v)=softmax(v) + c ; c=argmax(v) -softmax(v),且为常数
argmax(v)=softmax(v)+c;c=argmax(v)−softmax(v),且为常数
2. stop gradient operation
方法:正向传播就和往常一样,反向传播时,将梯度从不可导那个点copy到 不可导点的前面的最近一个可导点。
q
u
a
n
t
i
z
e
=
i
n
p
u
t
+
(
q
u
a
n
t
i
z
e
−
i
n
p
u
t
)
.
d
e
t
a
c
h
(
)
quantize = input + (quantize - input).detach()
quantize=input+(quantize−input).detach()
3. 可以对argmax/argmin 这种不可导的操作直接忽视,也就是锁定
就是抛弃不可传导的位置
class ArgMax(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
idx = torch.argmax(input, 1)
output = torch.zeros_like(input)
output.scatter_(1, idx, 1) # 此处直接用1来替换argmax的位置,抛弃了此处的梯度
return output
@staticmethod
def backward(ctx, grad_output):
return grad_output