pytorch:如何使用hook来debug

使用hook来调取网络结构中任一层的输入输出,只需要两部分函数:

1. 定义hook函数,以model,input,output为输入

class HookTool:
    def __init__(self) -> None:
        self.feature = None

    def hook_func(self,module,input,output):
        """
        用于处理feature的hook函数必须包含三个参数[module,input,output]
        module:torch里的一个子module
        input:该module的输入
        output:该module的输出
        """
        self.feature = output

2. 调用register_forward_hook获取前向传播各层输出

def get_feature_by_hook(model):
    """
    通过hook获取模型任意一层的输出
    """
    feature_name = []
    feature_hook = []
    feature_dict = {}
    for name,module in model.named_modules():
        # if name in "decoder.blks.decoder_layer0.ffn.activation":
        # if isinstance(module, torch.nn.Linear):
        cur_hook = HookTool()
        # register_forward_hook :Registers a forward hook on the module.
        # The hook will be called every time after forward() has computed an output.
        module.register_forward_hook(cur_hook.hook_func)
        feature_name.append(name)
        feature_hook.append(cur_hook)
        feature_dict[name] = cur_hook

    return feature_dict # feature_name,feature_hook,

3. 实际运行

# 定义网络
net = EncoderDecoder(encoder,decoder)
# 调用hook
hook_dict = get_feature_by_hook(net)

# 前向传播
output = net(xx)

# 输出每层输出
print(hook_dict.keys())
print([i.feature for i in hook_dict.values()])

此外,hook还可以追踪每层的反向传播梯度,使用 module.register_full_backward_hook

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Kiki酱。

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值