什么是hook函数?

什么是hook函数?

在深度学习和机器学习中,hook函数是一种用于在特定事件发生时执行自定义操作的机制。

通俗一点讲,Hook钩子能在执行目标代码之前或之后,拦截数据或者执行逻辑,先执行自己插入的一段代码,然后再执行后续的代码。

image

如上图所示,在按顺序运行的代码块code1和code2之间加上了一个钩子,那么在时间运行的过程中,无论是否注册钩子函数(hook_fn相当于钩子上悬挂的砝码块,虚线表示可能注册了,可能没注册),都会检查钩子上是否有注册的钩子函数,如果有的话则会运行该钩子函数。

hook函数在mindspore中的作用

监控和调试

可以在模型的前向传播或反向传播过程中插入hook,以监控中间层的输出或梯度。这对于理解模型的内部工作机制非常有用。

  • register_forward_hook(hook_fn) :hook_fn (function) - 捕获Cell对象信息和正向输入,输出数据的 hook_fn 函数。
  • register_backward_hook(hook_fn) : hook_fn (function) - 捕获Cell对象信息和反向输入,输出梯度的 hook_fn 函数。
  • register_forward_pre_hook(hook_fn) :hook_fn (function) - 捕获Cell对象信息和正向输入数据的hook_fn函数。

修改梯度

在反向传播过程中,可以使用hook来修改梯度。例如,进行梯度裁剪或应用特定的正则化技术。

  • mindspore.Tensor.register_hook(hook_fn) : hook_fn (function) - 捕获Tensor反向传播时的梯度,并输出或更改该梯度的 hook_fn 函数。

检查加载的模型

用于在加载模型的状态字典后执行自定义操作。允许你在加载权重后进行一些额外的处理,比如检查模型缺失的键和多余的键,也可以将缺失的键和多余的键清空,避免模型报错。

  • register_load_param_into_net_post_hook(hook_fn) : 捕获Cell对象信息和状态字典的hook_fn函数。**

hook函数实现流程

下面以register_load_param_into_net_post_hook函数实现为例,解释hook函数的实现流程。

  1. 定义钩子函数post_load_hook()
  2. 运行注册钩子函数将post_load_hook()注册到Cell中
  3. 运行load_param_into_net函数,运行code1加载模型的状态字典,然后触发检查钩子的函数
  4. 运行run_load_param_into_net_post_hook()函数检查是否有钩子函数注册,如果有则运行
  5. 先运行注册的钩子函数,在运行后续的代码块code2

在Cell模块中定义注册钩子的函数

def register_load_param_into_net_post_hook(self, hook_fn):
		# 检查是否在图模式下,在图模式下时钩子函数不起作用
        if context._get_mode() == context.GRAPH_MODE:
            return HookHandle()
		# 检查传入的`hook_fn`是否符合要求
        if not check_hook_fn("register_load_param_into_net_post_hook", hook_fn):
            return HookHandle()
		# 获取移除钩子函数的handle
        handle = HookHandle(self._load_param_into_net_post_hook)
		# 将注册的钩子函数存入到有序字典中
        self._load_param_into_net_post_hook[handle.handle_id] = hook_fn
        return handle

在Cell模块中定义运行钩子的函数

该函数遍历注册的钩子函数并逐一运行。

def run_load_param_into_net_post_hook(self, missing_keys, unexpected_keys):
        # Note that the hook can modify missing_keys and unexpected_keys.
        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in self._load_param_into_net_post_hook.values():
            out = hook(self, incompatible_keys)
            assert out is None, (
                "Hooks registered with ``register_load_param_into_net_post_hook`` are not"
                "expected to return new values, if incompatible_keys need to be modified,"
                "it should be done inplace."
            )

在加载状态字典的函数中调用运行钩子的函数

def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundancy=False):
 	# 加载模型并检查缺失的键值(param_not_load)和多余的键值(ckpt_not_load)
    ......
	# 运行钩子函数,打印或处理缺失的键值(param_not_load)和多余的键值(ckpt_not_load)
    net.run_load_param_into_net_post_hook(param_not_load,ckpt_not_load) 
    if param_not_load and not strict_load:
        _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)

	......

    return param_not_load, ckpt_not_load

在实际应用过程中注册函数

import mindspore.nn as nn
import mindspore as ms
from mindspore import load_param_into_net

# 定义模型结构
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 32, 3)
    def construct(self, x):
        x = self.conv(x)
        x = self.conv2(x)
        return x

net = Net()

# 定义钩子函数
def post_load_hook(cell, incompatible_keys):
    print("load_param_into_net post load hook triggered.")
    print(incompatible_keys.missing_keys)
    print(incompatible_keys.unexpected_keys)
    incompatible_keys.missing_keys.clear()
    incompatible_keys.unexpected_keys.clear()

# 注册钩子函数
handle = net.register_load_param_into_net_post_hook(post_load_hook)
param_dict = {
    'conv.weight': ms.Tensor(ms.numpy.randn(20, 1, 5, 5)),
    'conv2.weight': ms.Tensor(ms.numpy.randn(32, 20, 3, 3))
}
param_list = [(k, ms.Parameter(v)) for k, v in param_dict.items()]
param_dict = {k: v for k, v in param_list}

# 运行加载模型的函数并触发钩子函数的运行
load_param_into_net(net, param_dict, strict_load = False)  

# 移除定义的钩子函数
handle.remove()
  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值