前言
hook获取中间层输出
参考https://zhuanlan.zhihu.com/p/87853615
一、定义hook
# 用于存储各层的输入输出
module_name = []
features_in_hook = []
features_out_hook = []
def hook(module, fea_in, fea_out):
module_name.append(module.__class__)
features_in_hook.append(fea_in)
features_out_hook.append(fea_out)
return None
二、注册钩子
register_forward_hook()函数必须在forward()函数调用之前被使用
model=你定义的model()
print(model)
# 一:获取特定层
layer_name = 'xxx'
# xxx为对应层的名字,可以print一下model看需要获取哪些层
for (name, module) in model.named_modules():
if name == layer_name:
module.register_forward_hook(hook=hook)
# 二:获取全部层
net_children=model.children()
for child in net_children:
child.register_forward_hook(hook=hook)
model(X) # forward后可以输出相应层的输入输出。
print(features_in_hook)
print(features_out_hook)