最近程序调试遇到了hook函数,基础的知识不足以帮我理解hook,就此做个笔记吧!
1.hook函数概念
hook:钩子,也就常称之为钩子函数/挂钩函数
维基百科:hook函数是计算程序设计术语,指通过拦截模块间的函数调用、信息传递、事件传递来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。
具体而言,处理被拦截的函数调用、事件、消息的代码,被称之为 hook。
2.Pytorch 的 Hook 优势
在不改变网络输入输出的结构的基础上,获取、改变网络中间层变量的值和梯度。这个功能被广泛应用于可视化神经网络中间层的feature、gradient,从而诊断神经网络中可能出现的问题,分析网络的有效性。
实质:获取正向传播/反向传播的中间层的feature / gradient。
- 在forward之前注册hook,hook在forward执行以后被自动执行。
- module.register_forward_hook(hook_fn):必须在forward()函数调用之前被使用。
- f :hook_fn( module, input , output) : 变量-模块、模块的输入、模块的输出。
- module.register_backward_hook(hook_fn)
- b :hook_fn( module, grad_input , grad_output) : 变量-模块、输入端梯度、输出端梯度。
【学习资源】