在 PyTorch 中使用 hook 获取网络结构中任一层的输入和输出，只需要以下两部分函数：
1. 定义 hook 函数，以 module、input、output 三个参数为输入（module 是注册了该 hook 的子模块）
class HookTool:
    """Hold the most recent input and output seen by a forward hook.

    Register ``hook_func`` on a module with ``module.register_forward_hook``.
    After each forward pass of that module, ``feature`` holds the module's
    output and ``input`` holds the inputs it was called with.
    """

    def __init__(self) -> None:
        self.feature = None  # last output of the hooked module
        self.input = None  # last inputs of the hooked module

    def hook_func(self, module, input, output):
        """Forward-hook callback that records the module's input and output.

        A forward hook must accept exactly three arguments:
            module: the submodule the hook is registered on.
            input: the positional arguments passed to ``module.forward``
                (PyTorch passes these as a tuple).
            output: the value returned by ``module.forward``.

        Returns None, so the module's output is left unchanged.
        """
        # Stored as-is (no detach/copy): the values are overwritten on every
        # forward call and, for tensors, may keep their autograd graph alive
        # until then.
        self.input = input
        self.feature = output
2. 调用 register_forward_hook 为各层注册 hook，在前向传播时获取各层的输出
def get_feature_by_hook(model, module_filter=None):
    """Register a capturing forward hook on every (or every selected) submodule.

    Args:
        model: a module exposing ``named_modules()``. The model itself is
            yielded by ``named_modules()`` under the empty name ``""``.
        module_filter: optional callable ``(name, module) -> bool``. Only
            modules for which it returns True get a hook. ``None`` (the
            default) hooks every module. Example:
            ``lambda name, m: isinstance(m, torch.nn.Linear)``.

    Returns:
        dict mapping each hooked module's name to its ``HookTool``. After a
        forward pass, ``feature_dict[name].feature`` holds that module's
        output. Each tool also has a ``handle`` attribute (the value returned
        by ``register_forward_hook``); call ``tool.handle.remove()`` to detach
        the hook. Otherwise the hook stays registered, and calling this
        function again adds further hooks to the same modules.
    """
    feature_dict = {}
    for name, module in model.named_modules():
        if module_filter is not None and not module_filter(name, module):
            continue
        cur_hook = HookTool()
        # The hook runs every time after module.forward() has computed an
        # output. Keep the returned handle so the hook can be removed later.
        cur_hook.handle = module.register_forward_hook(cur_hook.hook_func)
        feature_dict[name] = cur_hook
    return feature_dict
3. 实际运行
# Build the network.
net = EncoderDecoder(encoder, decoder)
# Attach a capturing forward hook to every submodule.
hook_dict = get_feature_by_hook(net)
# Run a forward pass; each hook records its module's output.
output = net(xx)
# Show the hooked module names, then every captured output in the same order.
layer_names = hook_dict.keys()
print(layer_names)
captured = [tool.feature for tool in hook_dict.values()]
print(captured)
此外，hook 还可以追踪每层反向传播的梯度：使用 module.register_full_backward_hook 注册，其 hook 函数的参数为 (module, grad_input, grad_output)。