pytorch的自动求导和hook技术简介

在Pytorch的计算图中,由用户手动创建的tensor称为叶子节点,默认不计算梯度,即requires_grad=False。由函数(Function)计算得到的tensor是非叶子节点,requires_grad由参与计算的tensor决定。在整个计算图中,只要有一个节点的requires_grad=True,则所有依赖该节点的节点的requires_grad均为True。

tensor.grad_fn指向创建tensor的Function;叶子节点的grad_fn为None

x = torch.tensor([2], dtype=float, requires_grad=True)
y = torch.tensor([6], dtype=float, requires_grad=True)
y_ = y / 2
z = x ** 2 + y_ * 2
print(x)  # tensor([2.], dtype=torch.float64, requires_grad=True)
print(y)  # tensor([6.], dtype=torch.float64, requires_grad=True)
print(y_) # tensor([3.], dtype=torch.float64, grad_fn=<DivBackward0>)
print(z)  # tensor([10.], dtype=torch.float64, grad_fn=<AddBackward0>)

print(x.requires_grad, y.requires_grad, y_.requires_grad, z.requires_grad)  # True True True True
print(x.is_leaf, y.is_leaf, y_.is_leaf, z.is_leaf)  # True True False False
print(x.grad_fn, y.grad_fn, y_.grad_fn, z.grad_fn)  
# None None  <DivBackward0 object at 0x0000025FF3080EF0> <AddBackward0 object at 0x0000025FF3080F98>

grad_fn.next_functions是上级节点的grad_fn:

print(z.grad_fn.next_functions)  
# ((<PowBackward0 object at 0x0000025FF3080F98>, 0), (<MulBackward0 object at 0x0000025FF315C128>, 0))
print(z.grad_fn.next_functions[0][0].next_functions)  
# ((<AccumulateGrad object at 0x0000025FF3080EF0>, 0),)
print(z.grad_fn.next_functions[1][0].next_functions)  
# ((<DivBackward0 object at 0x0000025FF3080EF0>, 0), (None, 0))
print(z.grad_fn.next_functions[1][0].next_functions[0][0].next_functions)  
# ((<AccumulateGrad object at 0x0000025FF3080F98>, 0), (None, 0))

z.backward()  # retain_graph=False
print(x.grad, y.grad, y_.grad, z.grad)  
# tensor([4.], dtype=torch.float64) tensor([1.], dtype=torch.float64) None None
# y_是中间节点,故其梯度不被保存,为None

反向传播的中间缓存会被清空,若需要进行多次反向传播则指定retain_graph=True来保存这些缓存

# 接上
z.backward()
# RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

若不再需要某节点的梯度,则将其从当前计算图中分离:

x_ = x.detach()  # 返回一个不计算梯度的新tensor,该tensor与原tensor共享内存
print(x_.grad)  # None
z.backward(retain_graph=True)
print(x.grad, y.grad, y_.grad, z.grad)  # x和y的梯度均继续累积
# tensor([8.], dtype=torch.float64) tensor([2.], dtype=torch.float64) None None
x.detach_()  # 将原tensor从计算图中分离
z.backward()
print(x.grad, y.grad, y_.grad, z.grad)  # x的梯度不再累积,y的梯度继续累积
# tensor([8.], dtype=torch.float64) tensor([3.], dtype=torch.float64) None None

多次反向传播时,每个tensor的梯度都是累加的,故此时要得到正确的梯度必须先清空历史梯度:

import torch

x = torch.tensor([2], dtype=float, requires_grad=True)
y = torch.tensor([6], dtype=float, requires_grad=True)
y_ = y / 2
z = x ** 2 + y_ * 2

z.backward(retain_graph=True)
print(x.grad, y.grad, y_.grad, z.grad)
# tensor([4.], dtype=torch.float64) tensor([1.], dtype=torch.float64) None None

y_.retain_grad() # 保留指定的非叶节点的梯度;必须定义在backward()之前
z.backward(retain_graph=True)
print(x.grad, y.grad, y_.grad, z.grad)
# tensor([8.], dtype=torch.float64) tensor([2.], dtype=torch.float64) tensor([2.], dtype=torch.float64) None
# 此时梯度被累积

x.grad.zero_()
z.backward(retain_graph=True)
print(x.grad, y.grad, y_.grad, z.grad)
# tensor([4.], dtype=torch.float64) tensor([3.], dtype=torch.float64) tensor([4.], dtype=torch.float64) None
# x的梯度事先被清零,计算正确;其他节点的梯度继续累积

非叶子节点的梯度计算完之后即被清空,可以使用autograd.grad或retain_grad或hook技术获取非叶子节点的值。

torch.autograd.grad用法如下,给定具有依赖关系的两个节点就能计算出相应的梯度。可根据需要指定retain_graph的值。

print(torch.autograd.grad(z, [x, y, y_], retain_graph=True))
# (tensor([4.], dtype=torch.float64), tensor([1.], dtype=torch.float64), tensor([2.], dtype=torch.float64))

retain_grad在之前的代码中已经使用过,具体用法不再赘述。下面主要介绍hook技术(通过阅读源码可以发现,retain_grad也是借助hook技术来实现的)。

利用hook,我们可以在不改变网络输入输出的结构,方便地获取、改变网络中间节点的值和梯度。hook分为针对tensor的hook和针对module的hook。先介绍针对tensor的hook,用法为tensor.register_hook(grad_fn),grad_fn为自定义函数,用于对tensor的梯度做相关处理,输入为tensor的梯度,输出为一个tensor(改变回传的梯度值)或None(不改变回传的梯度值)。也可在一个节点处添加多个hook,按照定义的顺序依次执行。register_hook返回一个RemovableHandle类的对象,执行该类的remove方法可将计算图中相应的hook移除。

h1 = y_.register_hook(lambda grad: print(grad))
h2 = y_.register_hook(lambda grad: 2*grad)
z.backward()
print(y.grad, y_.grad)
# tensor([2.], dtype=torch.float64)
# tensor([5.], dtype=torch.float64) tensor([6.], dtype=torch.float64)

h1.remove()  # hook函数应在使用后及时移除,以免增加计算量
h2.remove()

y_节点的梯度计算完成后,依次执行两个hook函数。第一个函数打印该节点的梯度值,结果为2;第二个函数将该节点的梯度值*2后再反向传播,即该节点传给y节点的梯度值为4,则y节点的梯度值由1变为2,在累加到之前的结果上,故y.grad的值变为5。需要注意的是,hook函数改变的是节点传出的梯度值大小,其梯度本身并未改变,仍为2,故累加到之前的结果上,y_.grad的值为6。

由上述分析可知,针对tensor的hook的本质是,在节点梯度计算完成后且向后传播前,都调用一次grad_fn函数,在grad_fn中对梯度进行处理,将处理后的梯度向后传播。

将上述代码综合一下:

import torch

x = torch.tensor([2], dtype=float, requires_grad=True)
y = torch.tensor([6], dtype=float, requires_grad=True)
y_ = y / 2
z = x ** 2 + y_ * 2

z.backward(retain_graph=True)
print(x.grad, y.grad, y_.grad, z.grad)
# tensor([4.], dtype=torch.float64) tensor([1.], dtype=torch.float64) None None

print(torch.autograd.grad(z, [x, y, y_], retain_graph=True))
# (tensor([4.], dtype=torch.float64), tensor([1.], dtype=torch.float64), tensor([2.], dtype=torch.float64))

y_.retain_grad() # 保留指定的非叶节点的梯度;必须定义在backward()之前
z.backward(retain_graph=True)
print(x.grad, y.grad, y_.grad, z.grad)
# tensor([8.], dtype=torch.float64) tensor([2.], dtype=torch.float64) tensor([2.], dtype=torch.float64) None
# 此时梯度被累积

x.grad.zero_()
z.backward(retain_graph=True)
print(x.grad, y.grad, y_.grad, z.grad)
# tensor([4.], dtype=torch.float64) tensor([3.], dtype=torch.float64) tensor([4.], dtype=torch.float64) None
# x的梯度事先被清零,计算正确;其他节点的梯度继续累积

h1 = y_.register_hook(lambda grad: print(grad))
h2 = y_.register_hook(lambda grad: 2*grad)
z.backward()
print(y.grad, y_.grad)
# tensor([2.], dtype=torch.float64)
# tensor([5.], dtype=torch.float64) tensor([6.], dtype=torch.float64)

h1.remove()  # hook函数应在使用后及时移除,以免增加计算量
h2.remove()

观察上述代码的输出结果,对比三种方法:

torch.autograd.grad和hook技术都是需要在进行一次反向传播来计算需要的梯度值,能够得到准确的梯度值。而retain_grad是保留之后所有反向传播过程中该节点的梯度缓存,若要得到准确的梯度值,则要先清空缓存。除此之外,hook技术能够改变节点回传的梯度值,作用更加广泛,推荐使用。

然而,在一个较大的网络中,我们很难获得中间变量显式的变量名,进而对其执行register_hook操作。此时要获得这些中间变量的梯度就要使用针对module的register_forward_hook和register_backward_hook来分别获得前向和反向传播时,中间变量输入和输出的 feature/gradient。
register_forward_hook(hook_fn)可获取module的feature前向传播时的输入和输出,其中hook_fn的函数签名为:

hook_fn(module, input, output) -> Tensor or None

register_backward_hook(hook_fn)可获取module的gradient反向传播时的输入和输出(注意此时的输入输出的方向是前向传播的方向),其中hook_fn的函数签名为:

hook_fn(module, grad_input, grad_output) -> Tensor or None

使用实例如下:

import torch
import torch.nn as nn

class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()

        self.fc1 = nn.Linear(1, 2, bias=False)
        self.fc2 = nn.Linear(2, 1, bias=False)
        self._initialize()

    def _initialize(self):
        self.fc1.weight = nn.Parameter(
            torch.tensor([[2.], 
                          [2.]]))
        self.fc2.weight = nn.Parameter(torch.tensor([2., 2.]))

    def forward(self, x):
        return self.fc2(self.fc1(x))


def forward_hook_fn(module, input, output):
    print(module)
    print('feature input:', input)
    print('feature output:', output)


def backward_hook_fn(module, grad_input, grad_output):
    print(module)
    print('gradient input:', grad_input)
    print('gradient output:', grad_output)


model = net()
modules = model.children()
for m in modules:
    m.register_forward_hook(forward_hook_fn)    # 作用于输出计算完成后,向前传播前
    m.register_backward_hook(backward_hook_fn)  # 作用与梯度计算完成后,向后传播前

x = torch.tensor([2.], requires_grad=True)
out = model(x)
out.backward()
print(x.grad)
# Linear(in_features=1, out_features=2, bias=False)
# feature input: (tensor([2.], requires_grad=True),)
# feature output: tensor([4., 4.], grad_fn=<SqueezeBackward3>)
# Linear(in_features=2, out_features=1, bias=False)
# feature input: (tensor([4., 4.], grad_fn=<SqueezeBackward3>),)
# feature output: tensor(16., grad_fn=<DotBackward>)
# Linear(in_features=2, out_features=1, bias=False)
# gradient input: (tensor([2., 2.]), tensor([4., 4.]))
# gradient output: (tensor(1.),)
# Linear(in_features=1, out_features=2, bias=False)
# gradient input: (tensor([[2., 2.]]),)
# gradient output: (tensor([2., 2.]),)
# tensor([8.])

当这两个函数均作用与包含子module的复合module时,结果如下:

model.register_forward_hook(forward_hook_fn)
model.register_backward_hook(backward_hook_fn)
# net(
#   (fc1): Linear(in_features=1, out_features=2, bias=False)
#   (fc2): Linear(in_features=2, out_features=1, bias=False)
# )
# feature input: (tensor([2.], requires_grad=True),)
# feature output: tensor(16., grad_fn=<DotBackward>)
# net(
#   (fc1): Linear(in_features=1, out_features=2, bias=False)
#   (fc2): Linear(in_features=2, out_features=1, bias=False)
# )
# gradient input: (tensor([2., 2.]), tensor([4., 4.]))
# gradient output: (tensor(1.),)

根据输出结果可知,register_forward_hook作用于复合module,只能得到最顶层复合module的输入和输出信息;register_backward_hook作用于复合module时,只能得到最后一个module的梯度信息。
另外,还有一个针对module的hook函数为register_forward_pre_hook(hook_fn),它的hook_fn作用于module对输入做前向运算之前,因此可改变module的输入。hook_fn的函数签名为:

hook_fn(module, input) -> Tensor or None

register_forward_pre_hook和register_forward_hook的区别是:前者作用于前向运算之前,可修改module的输入;后者作用于前向运算之后,可修改module向后传播的值。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值