Pytorch 容器 - 4. Module中hook注册:register_forward_pre_hook(),register_forward_hook()

目录

1. hook作用

2. register_forward_pre_hook(hook)

3. register_forward_hook()

4. 代码示例

5. 其他示例


1. hook作用

hook是一个可调用的对象,它预定义了函数声明(即函数参数,返回值,调用方式等)。当调用forward() / backward()时,module对应的输入输出都会传到hook上,并可以在hook中处理这些输入输出。因此hook可中进行一些如:可视化中间特征、冻结部分层的等操作

2. register_forward_pre_hook(hook)

  • 该函数在foward()之前运行
  • 该函数能够修改输入并将修改后的新的输入结果返回给 forward()
  • 如果想要移除hook函数可以使用 remove()

3. register_forward_hook()

  • 该函数在foward()之后运行
  • 该函数能够修改输出结果 (inplace)
  • 如果想要移除hook函数可以使用 remove()

4. 示例

import torch
import torch.nn as nn


class SumNet(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a, b, c):
        print("forward")
        return a + b + c

# 参数中的input为model(a, b, c)函数中传进来的a, b, c, return值为修改后的输入结果
# 同时该结果作为forward()函数的输入
def forward_pre_hook(module, input):
    print("forward_pre_hook")
    a, b, c = input
    return a + 10, b, c

# 参数中的input为a, b, c
# 参数中的output为forward()输出的结果, return值为修改后的模型输出结果
def forward_hook(module, input, output):
    print("forward_hook")
    return output + 100


model = SumNet()
# 通过调用register_forward_pre_hook/register_forward_hook 注册hook函数
model.register_forward_pre_hook(forward_pre_hook)
model.register_forward_hook(forward_hook)

a = torch.tensor(1, dtype=torch.float, requires_grad=True)
b = torch.tensor(2, dtype=torch.float, requires_grad=True)
c = torch.tensor(3, dtype=torch.float, requires_grad=True)
d = model(a, b, c)
print(d)

输出结果如下。从输出的顺序可以看到先执行forward_pre_hook(),然后执行forward(),最后执行
forward_hook()
。在执行forward_pre_hook()时,将输入的a, b, c = 1, 2, 3修改成了 a, b, c = 11, 2, 3.之后将修改后的a, b, c传递给 forward()函数作为输入,得到 a + b + c = 16。最后forward_hook()将16+100并返回给模型,因此模型得到最后的输出结果116。

forward_pre_hook
forward
forward_hook
tensor(116., grad_fn=<AddBackward0>)

 5. 其他示例

上述示例只展示了钩子函数的执行顺序,具体的可视化特征案例可参考这个网址,讲解和示例都很好:Pytorch里hook的介绍 - 简书

注意:为什么本文没有介绍 register_backward_hook(hook) 函数,因为官网中说到该函数目前存在bug,所以不建议使用

本文内容参考部分博客和视频:

[1]  Pytorch里hook的介绍 - 简书

[2]  [中文字幕] 深入理解 PyTorch 中的 Hook 机制_哔哩哔哩_bilibili

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值