上篇博客已经说到,torch对梯度求导,仅保留叶子节点的梯度。这里使用FGSM进行说明。
FGSM的公式为:
对损失函数进行反传,得到原图x的梯度方向,在梯度方向上添加定长的扰动。结果为:
这里打印了x梯度及其方向的[0,0,0,0:10]。但这里存在的一个问题是,我们仅能获得原图x的梯度,原图x为叶子结点。之后原图x送入网络,中间任何层的输出,均不是叶子结点,直至最后通过全连接输出结果(叶子节点)。
那应当如何保存,数据通过网络中间层而产生的梯度呢?
知乎上的解答写的很好:
开发者说是因为当初开发时设计的是,对于中间变量,一旦它们完成了自身反传的使命,就会被释放掉。因此使用register_hook保存中间变量的梯度值。简而言之,register_hook的作用是,当反传时,除了完成原有的反传,额外多完成一些任务。你可以定义一个中间变量的hook,将它的grad值打印出来,当然你也可以定义一个全局列表,将每次的grad值添加到里面去。
import torch
from torch.autograd import Variable
grad_list = []
def print_grad(grad):
grad_list.append(grad)
x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data
可以看到计算可以分为三个部分,第一部分输入x,第二部分y=x+2,第三部分y^2再取平均,如果最后对z进行反传,最后保留的梯度应当只有x,y因为是中间变量,而没有保留梯度。
如果使用了register_hook,则能够保存住非叶结点y的梯度。具体来说,
假设有
x
=
1
x=1
x=1,
y
=
2
∗
(
x
+
3
)
y=2*(x+3)
y=2∗(x+3),
z
=
y
3
z=y^3
z=y3,
m
=
3
z
m=3z
m=3z。则根据链式法则有:
∂
m
∂
x
=
∂
m
∂
z
∗
∂
z
∂
y
∗
∂
y
∂
x
\frac{\partial m}{\partial x}=\frac{\partial m}{\partial z}*\frac{\partial z}{\partial y}*\frac{\partial y}{\partial x}
∂x∂m=∂z∂m∗∂y∂z∗∂x∂y。
如果直接对m进行反向传播,得到的梯度为:
3
∗
3
∗
y
2
∗
2
=
1152
3*3*y^2*2=1152
3∗3∗y2∗2=1152,其中
∂
m
∂
z
=
3
,
∂
z
∂
y
=
3
y
2
,
∂
y
∂
x
=
2
\frac{\partial m}{\partial z}=3,\frac{\partial z}{\partial y}=3y^2,\frac{\partial y}{\partial x}=2
∂z∂m=3,∂y∂z=3y2,∂x∂y=2。
因为m为非叶结点,所以若想保存m的梯度,则使用x.register_hook,可以简单地认为将上面的求导法则计算到了最开始的变量x。结果为:
可以看到使用register_hook保存的梯度与计算的相同,同理可以将m.register_hook,z.register_hook,y.register_hook保存下来看看结果。
可以看出,简单理解register_hook就是求导至指定的变量。
对中间有2个变量的式子进行求导,结果也符合预期。