说明
- 在深度学习中,"钩子"通常指的是在模型训练或推理过程中插入的一些回调函数或处理程序,以执行额外的操作或监控模型的行为。这些钩子可以分为两种类型:张量钩子和模块钩子。
-
张量钩子(Tensor Hooks):
张量钩子是与模型中的具体张量(tensor)相关联的。通过在张量上注册钩子,可以在张量的计算中执行自定义的操作,例如记录梯度、修改张量的值等。这对于调试、可视化和梯度的处理非常有用。在PyTorch中,可以使用
register_hook
方法来添加张量钩子。例子:
def tensor_hook(grad): # 自定义操作,可以在这里处理梯度信息 print("梯度信息:", grad) # 注册张量钩子 tensor.register_hook(tensor_hook)
-
模块钩子(Module Hooks):
模块钩子是与模型中的具体模块(layer、block等)相关联的。通过在模块上注册钩子,可以在模块的前向或后向传播中执行自定义操作,例如获取模块的输出、记录模块的参数等。在PyTorch中,可以使用
register_forward_hook
和register_backward_hook
方法来添加模块钩子。例子:
def forward_hook(module, input, output): # 自定义前向传播操作 print("输入:", input) print("输出:", output) def backward_hook(module, grad_input, grad_output): # 自定义反向传播操作 print("梯度输入:", grad_input) print("梯度输出:", grad_output) # 注册模块钩子 module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook)
张量钩子
我们要看的第一种类型
-
紫色为梯度,e处的梯度如果没有特殊指定的话,默认值为1
-
一旦我们向后调用e点(执行到e),通过这些节点的梯度的整个计算,对我们来说是不可接近的
-
当它们流过时,我们无法真正检查梯度,或者如果我们想改变他们,我们只能看到梯度是什么,输出到叶注释
-
这就是张量上的钩子的用武之地,它们允许我们在梯度向后流过图形时检查它们,并有可能改变它们
-
当你加钩子的时候,在这里我们加入第一个钩子,我们称C点寄存器钩,我们通过它接受梯度的函数,可选地返回一个新的梯度,如果你不从这个函数返回任何东西,就用和之前一样的梯度把它传递下去
-
所以当我们注册这个钩子的时候,它首先被添加到这个c张量上的向后钩子上,这是一本有序的字典,所以你把钩子加到张量上的顺序很重要,因为在向后的图表中,它们会按这个顺序被调用
-
接下来我们再注册一个钩子,这次我们只是给它传递一个lambda函数,我就打印出一个渐变,所以它不会改变梯度
我把它打印出来
它将继续使用向后图中的前一个渐变
你可以在这里看到它在向后钩子上添加了lambda函数 -
接下来我们叫C保持,如果要在中间节点上存储渐变,所以在这个例子中,A,B和D是叶节点
-
默认情况下,它们将是唯一获得存储到它们的渐变的节点,通过这些累加梯度节点,如果我们想要一个渐变存储在中间节点上
-
接下来我们将创建d张量,然后我们在d张量上注册一个钩子,这里它只是一个lambda函数
再加一百,因为它返回一个梯度,它将替换传递给它的梯度
-
所以需要注意的是,向中间节点和叶节点添加钩子是有区别的
-
向叶子节点添加钩子时,它只是把它添加到它的向后钩子有序字典中
-
但是当您向中间节点添加钩子时,第一次将钩子添加到它的向后钩子顺序字典中,您还通知与此张量关联的向后图中的节点,在这种情况下,你要加上这个张量的后钩到这个Pre钩子的向后节点列表
模块钩子
- 我们现在来看看模块上的钩子,这些会更容易理解,首先呢,一个典型的模块将有一个正向方法
- 这里我们只接受三个输入,我们将它们相加并返回输出
- 模块钩子是添加一个函数,该函数在这个前向方法之前被调用,或者在这个前向方法之后调用的函数