PyTorch中的梯度计算2

上篇博客已经说到,torch对梯度求导,仅保留叶子节点的梯度。这里使用FGSM进行说明。
FGSM的公式为:

在这里插入图片描述
对损失函数进行反传,得到原图x的梯度方向,在梯度方向上添加定长的扰动。结果为:
在这里插入图片描述
这里打印了x梯度及其方向的[0,0,0,0:10]。但这里存在的一个问题是,我们仅能获得原图x的梯度,原图x为叶子结点。之后原图x送入网络,中间任何层的输出,均不是叶子结点,直至最后通过全连接输出结果(叶子节点)。
那应当如何保存,数据通过网络中间层而产生的梯度呢?
知乎上的解答写的很好:
开发者说是因为当初开发时设计的是,对于中间变量,一旦它们完成了自身反传的使命,就会被释放掉。因此使用register_hook保存中间变量的梯度值。简而言之,register_hook的作用是,当反传时,除了完成原有的反传,额外多完成一些任务。你可以定义一个中间变量的hook,将它的grad值打印出来,当然你也可以定义一个全局列表,将每次的grad值添加到里面去。

import torch
from torch.autograd import Variable

grad_list = []

def print_grad(grad):
    grad_list.append(grad)

x = Variable(torch.randn(2, 1), requires_grad=True)
y = x+2
z = torch.mean(torch.pow(y, 2))
lr = 1e-3
y.register_hook(print_grad)
z.backward()
x.data -= lr*x.grad.data
可以看到计算可以分为三个部分,第一部分输入x,第二部分y=x+2,第三部分y^2再取平均,如果最后对z进行反传,最后保留的梯度应当只有x,y因为是中间变量,而没有保留梯度。

在这里插入图片描述
如果使用了register_hook,则能够保存住非叶结点y的梯度。具体来说,
假设有 x = 1 x=1 x=1, y = 2 ∗ ( x + 3 ) y=2*(x+3) y=2(x+3), z = y 3 z=y^3 z=y3, m = 3 z m=3z m=3z。则根据链式法则有:
∂ m ∂ x = ∂ m ∂ z ∗ ∂ z ∂ y ∗ ∂ y ∂ x \frac{\partial m}{\partial x}=\frac{\partial m}{\partial z}*\frac{\partial z}{\partial y}*\frac{\partial y}{\partial x} xm=zmyzxy
如果直接对m进行反向传播,得到的梯度为:
3 ∗ 3 ∗ y 2 ∗ 2 = 1152 3*3*y^2*2=1152 33y22=1152,其中 ∂ m ∂ z = 3 , ∂ z ∂ y = 3 y 2 , ∂ y ∂ x = 2 \frac{\partial m}{\partial z}=3,\frac{\partial z}{\partial y}=3y^2,\frac{\partial y}{\partial x}=2 zm=3,yz=3y2,xy=2
因为m为非叶结点,所以若想保存m的梯度,则使用x.register_hook,可以简单地认为将上面的求导法则计算到了最开始的变量x。结果为:
在这里插入图片描述
可以看到使用register_hook保存的梯度与计算的相同,同理可以将m.register_hook,z.register_hook,y.register_hook保存下来看看结果。
在这里插入图片描述
可以看出,简单理解register_hook就是求导至指定的变量。
对中间有2个变量的式子进行求导,结果也符合预期。
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值