18hook函数与CAM可视化

本文介绍了PyTorch中的Hook函数机制,包括Tensor和Module的四种Hook,以及它们在特征提取和可视化中的应用。特别讨论了CAM和Grad-CAM两种视觉解释方法,用于理解深度网络关注的图像区域。通过Hook函数,可以捕获中间特征图并进行可视化,而Grad-CAM通过梯度加权得到更精确的关注区域。
摘要由CSDN通过智能技术生成

一、Hook函数概念

1.1 Hook引入的原因

Pytorch的运行机制是动态计算图,动态图运算结束后,一些中间变量(如feature map和非叶子结点的梯度)会被释放掉,但是往往有时候我们需要获取这些中间变量,这时就可以通过Hook函数在主体中根据Hook机制添加额外的函数来获取或改变中间变量

1.2 Hook函数机制

Hook函数机制: 不改变主体(前向传播和后向传播),实现额外功能,像一个挂件,挂钩, hook

在这里插入图片描述
nn.module中的call()函数的运行机制也正是hook函数机制,整个call函数分为四个部分,分别是:

  • forward_pre_hook
  • forward
  • forward_hook
  • backward_hook

如上图所示,call()函数执行forward_pre_hook函数,然后执行forward前向传播过程,接着执行forward_hook函数,最后执行back_forward函数
所以,在前向传播过程中,不仅仅只是单纯地执行前项传播,而是会提供hook函数接口,来实现额外的操作和功能

1.3 四种hook函数

主要分为三类:针对tensor的,前向传播的,和后向传播的

  1. torch.Tensor.register_hook(hook)
  2. torch.nn.Module.register_forward_hook
  3. torch.nn.Module.register_forward_pre_hook
  4. torch.nn.Module.register_backward hook

二、Hook函数与特征提取

2.1 Tensor.register_hook

hook(grad)

功能: 注册一个反向传播hook函数

Hook函数仅一个输入参数,为张量的梯度,返回张量或者无返回

示例:通过hook函数获取和改变非叶子结点的梯度
在这里插入图片描述

# -*- coding:utf-8 -*-

import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # 设置随机种子


# ----------------------------------- 1 tensor hook 1 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()        # 存储张量的梯度

    def grad_hook(grad):
        a_grad.append(grad)

    handle = a.register_hook(grad_hook) # 把定义的函数注册到对应的张量上

    y.backward()

    # 查看梯度
    print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
    print("a_grad[0]: ", a_grad[0])
    handle.remove()


# ----------------------------------- 2 tensor hook 2 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()

    def grad_hook(grad):               # 定义hook函数修改张量梯度
        grad *= 2
        return grad*3                  # 通过return返回的梯度会覆盖掉原梯度

    handle 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值