pytorch中register_hook以及register_forward_hook

何为叶子节点和非叶子节点

在理解register_hook之前,首先得搞懂什么叶子节点和非叶子节。简单来说叶子节点是有梯度且独立得张量,例如a = torch.tensor(2.0,requires_grad=True),b= torch.tensor(3.0,requires_grad=True),非叶子节点是依赖其他张量而得到得张量如c = a+b
判断是叶子节点还是非叶子节点可以使用 is_leaf来判断一个张量是叶子节点还是非叶子节点。

import torch
a = torch.tensor(2.0,requires_grad=True)
b = torch.tensor(3.0,requires_grad=True)
print(a.is_leaf)
print(b.is_leaf)
c = a +b 
print(c.is_leaf)

>>> True
>>> True
>>> False

中间张量 c 作为非叶子节点是没有梯度信息得。pytorch默认在梯度反向传播过程中不会记录中间变量梯度信息。而且叶子节点的梯度信息在反向传播流过程中是不允许我们修改的。只能通过print(a.grad)查看张量的梯度信息。
那么,如果我们想查看中间变量 c 以及想改变叶子节点反向传播过程中的梯度值,应该怎么办呢。这时候就要使用register_hook这个钩子函数了。通过一下两段代码看一下钩子函数的主要作用。

register_hook
a = torch.tensor(2.0,requires_grad=True)
b = torch.tensor(3.0,requires_grad=True)
print(a.grad)
print(b.grad)
c = a*b
print(c.grad)  # 由于c是叶子节点,所以他是不记录梯度信息得。前后打印梯度信息都为None

d = torch.tensor(4.0,requires_grad=True)
e = c * d
e.backward()
print(a.grad)
print(b.grad)
print(c.grad)

>>>输出
None
None
None
tensor(12.)
tensor(8.)
None

通过上面代码可以看出,c作为中间变量在反向传播过程中不记录梯度信息。c=a*b其中a的梯度就为b的值,b的梯度就是a的值。接下来对中间变量c 使用register_hook,这个函数传入的参数得是一个函数。

import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b

def c_hook(grad):
    print("c_hook",grad)
    return grad + 2    # 什么也不返回的话用的是和之前一样的梯度,不对其进行变化。

# 在c中,钩子按照有序字典的方式存储,按照存储的前后一次调用
c.register_hook(c_hook)
c.register_hook(lambda grad: print("hello my grad is",grad))
c.retain_grad()   # 存储中间变量的梯度

print(a.grad)
print(b.grad)
print(c.grad)

c.backward()

print(a.grad)
print(b.grad)
print(c.grad)

>>>
None
None
None
c_hook tensor(1.)
hello my grad is tensor(3.)
tensor(9.)
tensor(6.)
tensor(3.)

为什么输出会是这样的结果呢,一个张量可以注册多个钩子函数,反向传播过程中按照注册的顺序依次运行。 c.register_hook(c_hook) c.register_hook(lambda grad:)
,这两个函数可以重写c的梯度,第一个函数传入的参数是c的梯度,自身对自身的梯度pytorch中默认为1。所以此时c_hook中传入的grad=1,这个函数返回值为grad+2=3,此时会重写中间变量c的梯度信息。第二个钩子函数传入的函数为匿名函数,这个匿名函数对c的梯度没有进行重写,使用的还是上一个钩子函数重写的值,此使打印信息就为3。最后通过c.retain_grad()记c的梯度信息。通过这个例子,我稍微懂了点register_hook这个钩子函数的作用,是不是本来不可修改的梯度信息值,通过这个函数修改了呢。

通过一下这个例子比较再来看一下registe_hook函数的作用。

import torch
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b


def c_hook(grad):
    print("c_hook",grad)
    return grad + 2    # 什么也不返回的话用的是和之前一样的梯度,不对其进行变化。

# 在c中,钩子按照有序字典的方式存储,按照存储的前后一次调用
c.register_hook(c_hook)
c.register_hook(lambda grad: print("hello my grad is",grad))
c.retain_grad()   # 存储中间变量的梯度

d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad + 100)  # 将使用100+grad代替本来返回得梯度值

e = c * d

print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(e.grad)


# e.retain_grad()
e.register_hook(lambda grad: grad * 2)
e.retain_grad()

e.backward()

print(a.grad)
print(b.grad)
print(c.grad)
print(d.grad)
print(e.grad)

>>>输出
None
None
None
None
None
c_hook tensor(8.)
hello my grad is tensor(10.)
tensor(30.)
tensor(20.)
tensor(10.)
tensor(112.)
tensor(2.)

这段代码前部分和前面的代码保持一致,后面添加了e = c * d,在反向传播前,毋庸置疑a,b,,c,d,e的梯度都为None。反向传播过程中首先看 e,自身对自身的倒数默认为1,但是e注册的钩子将对原本的梯度 * 2 ,来替代原先的梯度信息,所以打印出的e的梯度信息为2。相应的,e 对 c的梯度信息相应的就变为 2d=8,e对d的梯度信息就变为 2c=12,案例说此使d的梯度信息为12,为什么是112呢,可以看出d注册了一个钩子函数,这个钩子给d原本的梯度信息加了100,来代替旧的梯度信息,所以d的梯度信息为112。由于c注册的钩子函数给他加了2,所以c的梯度信息为10。相应的a b 的梯度就都要乘以c 的梯度信息了。 同样,原本不变的梯度信息值在这里都根据register_hook这个函数相应的被重写。
在这里插入图片描述
以上就是我根据视频链接对register_hook的理解。

register_forward_hook

register_forward_hook register_forward_pre_hook这个函数主要使用在nn.Module网络中。
第一个函数看名称是用在网络forward之前,第二个是运行在forward之后,举例:

import torch
import torch.nn as nn


class SumNet(nn.Module):
    def __init__(self):
        super(SumNet, self).__init__()

    @staticmethod
    def forward(a, b, c):
        d = a + b + c

        print('forward():')
        print('    a:', a)
        print('    b:', b)
        print('    c:', c)
        print()
        print('    d:', d)
        print()

        return d


def forward_pre_hook(module, input_positional_args):
    a, b, c = input_positional_args
    new_input_positional_args = a + 10, b,c+10

    print('forward_pre_hook():')
    print('    module:', module)
    print('    input_positional_args:', input_positional_args)
    print()
    print('    new_input_positional_args:', new_input_positional_args)
    print()

    return new_input_positional_args


def forward_hook(module, input_positional_args, output):
    new_output = output + 100

    print('forward_hook():')
    print('    module:', module)
    print('    input_positional_args:', input_positional_args)
    print('    output:', output)
    print()
    print('    new_output:', new_output)
    print()

    return new_output


def main():
    sum_net = SumNet()
    sum_net.register_forward_pre_hook(forward_pre_hook)
    sum_net.register_forward_hook(forward_hook)

    a = torch.tensor(1.0, requires_grad=True)
    b = torch.tensor(2.0, requires_grad=True)
    c = torch.tensor(3.0, requires_grad=True)

    print('start')
    print()
    print('a:', a)
    print('b:', b)
    print('c:', c)
    print()
    
    print('before model')
    print()

    d = sum_net(a, b, c)   # 前向传播得时候钩子函数起作用了,先是forward_pre_hook,接下来是forward,接下来是forward_hook函数。

    print('after model')
    print()
    print('d:', d)


if __name__ == '__main__':
    main()

输出信息:

start

a: tensor(1., requires_grad=True)
b: tensor(2., requires_grad=True)
c: tensor(3., requires_grad=True)

before model

forward_pre_hook():
    module: SumNet()
    input_positional_args: (tensor(1., requires_grad=True), tensor(2., requires_grad=True), tensor(3., requires_grad=True))

    new_input_positional_args: (tensor(11., grad_fn=<AddBackward0>), tensor(2., requires_grad=True), tensor(13., grad_fn=<AddBackward0>))

forward():
    a: tensor(11., grad_fn=<AddBackward0>)
    b: tensor(2., requires_grad=True)
    c: tensor(13., grad_fn=<AddBackward0>)

    d: tensor(26., grad_fn=<AddBackward0>)

forward_hook():
    module: SumNet()
    input_positional_args: (tensor(11., grad_fn=<AddBackward0>), tensor(2., requires_grad=True), tensor(13., grad_fn=<AddBackward0>))
    output: tensor(26., grad_fn=<AddBackward0>)

    new_output: tensor(126., grad_fn=<AddBackward0>)

after model

d: tensor(126., grad_fn=<AddBackward0>)

分析以上为什么会输出这样的结果,前面提到register_forward_hook这个函数会在网络前向传播前运行,需要两个参数modul 和 input案例中输入为 tensor 1 2 3,经过这个函数给2 3 分别加了10,并且返回了一组新的值,这组值是要传入forward中,可以看出,forward函数打印的a b c 为传入的这组新值,而不是刚开始定义的1 2 3,forward函数运行过程中返回每层的输出会运行forward_hook函数。这个函数主要需要三个参数,module input output
以下从Lenet网络来使用这个函数:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)     
        out = F.max_pool2d(out, 2)      
        
        out = self.conv2(out)
        out = F.relu(out)  
        out = F.max_pool2d(out, 2)
        
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out
model = LeNet()

# 分别对model的第一个卷积层和最后一层使用了钩子函数,这样既可以取出对应层的输出。
def hook(model,input_,output):
    print("最后一层输出:",output.shape)

def conv_hook(model,input_,output):
    print("conv1后",input_[0].shape,output.shape)

model.register_forward_hook(hook)
model.conv1.register_forward_hook(conv_hook)


img = torch.randn([1,3,32,32])
out_put = model(img)

>>>
conv1后 torch.Size([1, 3, 32, 32]) torch.Size([1, 6, 28, 28])
最后一层输出: torch.Size([1, 10])

基于上可以看出给不同层使用钩子函数,可以提取出每一层的输出,并进行相应的处理。

以上就是pytorch中register_hookregister_forward_hook的基本理解。
如果有问题烦请指出加以改正。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值