一、pytorch的hook机制
在神经网络的反向传播当中个,流程只保存叶子节点的梯度,对于中间变量的梯度没有进行保存,hook可以通过自定义一些函数,从而完成中间变量梯度的输出,比如中间特征图、中间层梯度修正等。pytorch有四个hook相关的函数:分别是
- register_hook 属于tensor类
- register_backward_hook 属于moudule类
- register_forward_hook 属于moudule类
- register_forward_pre_hook 属于moudule类
下面收集一些应用实例,部分例子来自于
涩醉:pytorch使用hook打印中间特征图、计算网络算力等,
Pytorch中autograd以及hook函数详解 - Oldpan的个人博客
1、register_hook 记录tensor中间变量的grad
pytorch 查看中间变量的梯度
参考链接:
pytorch 查看中间变量的梯度www.cnblogs.comgrads
2、register_backward_hook 属于moudule类
import