hook中间层输出


前言

hook获取中间层输出
参考https://zhuanlan.zhihu.com/p/87853615


一、定义hook

# 用于存储各层的输入输出
module_name = []
features_in_hook = []
features_out_hook = []

def hook(module, fea_in, fea_out):
    module_name.append(module.__class__)
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None

二、注册钩子

register_forward_hook()函数必须在forward()函数调用之前被使用

model=你定义的model()
print(model)
# 一:获取特定层
layer_name = 'xxx'
# xxx为对应层的名字,可以print一下model看需要获取哪些层
for (name, module) in model.named_modules():
    if name == layer_name:
    module.register_forward_hook(hook=hook)
    
# 二:获取全部层
net_children=model.children()
for child in net_children:
	child.register_forward_hook(hook=hook)

model(X) # forward后可以输出相应层的输入输出。
print(features_in_hook)
print(features_out_hook)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值