一个例子教会你使用Hook

文章介绍了PyTorch中的Hook函数,这是一种用于调试网络和获取中间变量的工具。通过示例展示了如何注册`register_forward_hook`来打印卷积层`conv2`的前向输出。此外,还提到了其他类型的Hook,如`register_backward_hook`用于反向传播过程,以及`register_load_state_dict_post_hook`用于参数处理。
摘要由CSDN通过智能技术生成

一个例子教会你使用pytorch的Hook函数

什么是hook? pytorch提供的一个函数,Hook函数能获取网络中的中间变量,包括forward和backward中的输入输出,如:weight、grad、loss等。
为什么叫hook? Hook函数机制是不改变函数主体,实现额外功能,像一个挂件,挂钩。
使用场景? 进行网络调试的时候,我们希望打印网络中的中间变量Tesnor(网络中涉及到的Tensor),所以就出现了Hook函数。我们可以使用Hook函数优雅地(在不更改网络结构代码的情况下)打印网络中的grad、weight、以及每层的输入输出结果。

PyTorch提供的Hook函数(这里仅为示例,不止四种):

1、torch.Tensor.register_hook(hook_func)            //获取Tensor
2、torch.nn.Module.register_forward_hook(hook_func)         //获取前向过程中的中间变量
3、torch.nn.Module.register_forward_pre_hook(hook_func)        //获取前向过程前的变量
4、torch.nn.Module.register_backward_hook(hook_func)        //获取反向过程中的中间变量

下面是一个使用Hook(register_forward_hook)函数的例子:

  • 首先定义一个网络结构
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  • 针对上述定义的网络,介绍注册函数register_forward_hook和自定义函数hook_fn_forward的使用方法(打印指定层conv2的前向输出)
# 定义 forward hook function,输入是module,将其前向传播过程中,
# module的输入输出存到全局变量total_feat_in和total_feat_out中,方便后续调用或查看
total_feat_out = []
total_feat_in = []
def hook_fn_forward(module, input, output):
    print("打印层:",module) # 用于区分模块
    print(module, '层的forward input:', input) # 首先打印出来
    print(module, '层的forward output:', output)
    total_feat_in.append(input)# 然后分别存入全局 list 中
    total_feat_out.append(output) 

#定义模型
model=LeNet()

# 对conv2层注册hook
modules = model.named_children()
for name, module in modules:
    if name == "conv2":
        module.register_forward_hook(hook_fn_forward)
        #module.register_full_backward_hook(hook_fn_backward)

#构造输入,跑一遍forward前向传播
fake_img = torch.randn(4, 3, 32, 32)
output = model(fake_img)

上述代码运行结果如下:
在这里插入图片描述

  • 另外几种hook函数使用方式也类似,以下是一些简单的示例,大家可以作为参考
def grad_tensor_hook(grad):
    print("-----grad tensor:", grad)
    return grad

def hook_fn_backward(module, grad_input, grad_output):
    print(module) # 为了区分模块
    # 为了符合反向传播的顺序,我们先打印 grad_output
    print('grad_output', grad_output)
    # 再打印 grad_input
    print('grad_input', grad_input)
    
def parameter_hook(module,incompatible_keys):
    print("parameter module:", module)
    print("incompatible_keys:",incompatible_keys)
    
modules = model.named_children() #
for name, module in modules:
    module.register_full_backward_hook(hook_fn_backward)
    module.register_load_state_dict_post_hook(parameter_hook)
一个完整的hook程序的例子 一、客户端 程序命名为Client。监视系统的运行,如发现系统中有“记事本”进程(notepad.exe)或者“计算器”进程(calc.exe),立即杀死(kill)该进程,并将该事件写入数据库;定期进行检查,每间隔1分钟,检查数据库,将尚未上传的事件记录上传至服务器端。 1、目标运行环境为Windows 2000操作系统。 2、程序请设计为系统服务。 3、程序需具备抗攻击能力,包括反删除、对抗强制终止进程等功能。 (1)保持程序的持续运行,防止其他程序强行终止当前程序的运行; (2)保护事件数据库和主执行文件不被删除; (3)如发现异常(进程被终止、文件被删除等),立即强制重新载入/运行程序; (4)如连续3次发现异常,守护进程强制重新启动操作系统;重启系统后保证程序正常载入并运行。 (A)为实现以上功能,程序不限于以EXE形式实现,可根据需要自行决定实现形式。 (B)以上功能均在Windows 2000正常运行环境、Administrator权限下实现,无需考虑Windows安全模式或者权限等问题。 4、请使用简单桌面数据库,如Access、xBase等文件型数据库。 5、每次生成的事件至少包含2部分信息:事件发生时间和事件处理对象。 数据库中以表tEvent存储事件数据。tEvent表至少包含两个字段: (1)EventTime字段:时间/日期类型。记录事件发生的时间。 (2)EventTarget字段:字符类型。记录事件中所Kill的对象。需要考虑的对象有记事本进程和计算器进程。 如果需要其他表或者字段,可根据需要自行添加。 6、网络数据传输格式自定。传输的具体内容和格式请根据需要自行决定,不做具体要求。客户端网络需与服务器端网络配合工作。 7、所用开发语言与集成开发环境不限,可自行选择。 8、对于数据库连接方式,请根据需要自行选择。 二、服务器端 程序命名为Server。监听网络,一旦有客户端上传数据,立即从中提取事件信息,并在用户界面中以列表方式加以显示。 1、目标运行环境为Windows 2000操作系统。 2、程序请设计为普通Windows 2000 GUI应用程序。在用户界面中至少需包含一个事件信息列表,该列表中至少包含3部分信息:事件发生时间、事件处理对象和事件来源。 (1)事件发生时间:同客户端事件发生时间。 (2)事件处理对象:同客户端事件处理对象。 (3)事件来源:上传当前事件的客户端机器的IP地址。 3、网络数据传输格式自定;与客户端配合工作。 4、所用开发语言与集成开发环境不限,可自行选择。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值