目录
2. register_forward_pre_hook(hook)
1. hook作用
hook是一个可调用的对象,它预定义了函数声明(即函数参数,返回值,调用方式等)。当调用forward() / backward()时,module对应的输入输出都会传到hook上,并可以在hook中处理这些输入输出。因此hook可中进行一些如:可视化中间特征、冻结部分层的等操作
2. register_forward_pre_hook(hook)
- 该函数在foward()之前运行
- 该函数能够修改输入并将修改后的新的输入结果返回给 forward()
- 如果想要移除hook函数可以使用 remove()
3. register_forward_hook()
- 该函数在foward()之后运行
- 该函数能够修改输出结果 (inplace)
- 如果想要移除hook函数可以使用 remove()
4. 示例
import torch
import torch.nn as nn
class SumNet(nn.Module):
def __init__(self):
super().__init__()
def forward(self, a, b, c):
print("forward")
return a + b + c
# 参数中的input为model(a, b, c)函数中传进来的a, b, c, return值为修改后的输入结果
# 同时该结果作为forward()函数的输入
def forward_pre_hook(module, input):
print("forward_pre_hook")
a, b, c = input
return a + 10, b, c
# 参数中的input为a, b, c
# 参数中的output为forward()输出的结果, return值为修改后的模型输出结果
def forward_hook(module, input, output):
print("forward_hook")
return output + 100
model = SumNet()
# 通过调用register_forward_pre_hook/register_forward_hook 注册hook函数
model.register_forward_pre_hook(forward_pre_hook)
model.register_forward_hook(forward_hook)
a = torch.tensor(1, dtype=torch.float, requires_grad=True)
b = torch.tensor(2, dtype=torch.float, requires_grad=True)
c = torch.tensor(3, dtype=torch.float, requires_grad=True)
d = model(a, b, c)
print(d)
输出结果如下。从输出的顺序可以看到先执行forward_pre_hook(),然后执行forward(),最后执行
forward_hook()。在执行forward_pre_hook()时,将输入的a, b, c = 1, 2, 3修改成了 a, b, c = 11, 2, 3.之后将修改后的a, b, c传递给 forward()函数作为输入,得到 a + b + c = 16。最后forward_hook()将16+100并返回给模型,因此模型得到最后的输出结果116。
forward_pre_hook
forward
forward_hook
tensor(116., grad_fn=<AddBackward0>)
5. 其他示例
上述示例只展示了钩子函数的执行顺序,具体的可视化特征案例可参考这个网址,讲解和示例都很好:Pytorch里hook的介绍 - 简书
注意:为什么本文没有介绍 register_backward_hook(hook) 函数,因为官网中说到该函数目前存在bug,所以不建议使用
本文内容参考部分博客和视频: