18hook函数与CAM可视化

本文介绍了PyTorch中的Hook函数机制,包括Tensor和Module的四种Hook,以及它们在特征提取和可视化中的应用。特别讨论了CAM和Grad-CAM两种视觉解释方法,用于理解深度网络关注的图像区域。通过Hook函数,可以捕获中间特征图并进行可视化,而Grad-CAM通过梯度加权得到更精确的关注区域。
摘要由CSDN通过智能技术生成

一、Hook函数概念

1.1 Hook引入的原因

Pytorch的运行机制是动态计算图,动态图运算结束后,一些中间变量(如feature map和非叶子结点的梯度)会被释放掉,但是往往有时候我们需要获取这些中间变量,这时就可以通过Hook函数在主体中根据Hook机制添加额外的函数来获取或改变中间变量

1.2 Hook函数机制

Hook函数机制: 不改变主体(前向传播和后向传播),实现额外功能,像一个挂件,挂钩, hook

在这里插入图片描述
nn.module中的call()函数的运行机制也正是hook函数机制,整个call函数分为四个部分,分别是:

  • forward_pre_hook
  • forward
  • forward_hook
  • backward_hook

如上图所示,call()函数执行forward_pre_hook函数,然后执行forward前向传播过程,接着执行forward_hook函数,最后执行back_forward函数
所以,在前向传播过程中,不仅仅只是单纯地执行前项传播,而是会提供hook函数接口,来实现额外的操作和功能

1.3 四种hook函数

主要分为三类:针对tensor的,前向传播的,和后向传播的

  1. torch.Tensor.register_hook(hook)
  2. torch.nn.Module.register_forward_hook
  3. torch.nn.Module.register_forward_pre_hook
  4. torch.nn.Module.register_backward hook

二、Hook函数与特征提取

2.1 Tensor.register_hook

hook(grad)

功能: 注册一个反向传播hook函数

Hook函数仅一个输入参数,为张量的梯度,返回张量或者无返回

示例:通过hook函数获取和改变非叶子结点的梯度
在这里插入图片描述

# -*- coding:utf-8 -*-

import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # 设置随机种子


# ----------------------------------- 1 tensor hook 1 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()        # 存储张量的梯度

    def grad_hook(grad):
        a_grad.append(grad)

    handle = a.register_hook(grad_hook) # 把定义的函数注册到对应的张量上

    y.backward()

    # 查看梯度
    print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
    print("a_grad[0]: ", a_grad[0])
    handle.remove()


# ----------------------------------- 2 tensor hook 2 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()

    def grad_hook(grad):               # 定义hook函数修改张量梯度
        grad *= 2
        return grad*3                  # 通过return返回的梯度会覆盖掉原梯度

    handle 
  • 1
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
要可UNet模型中的特征图,您可以使用PyTorch中的hook函数hook函数允许您在模型的特定层上注册一个函数,该函数将在每次前向传播时被调用,并且可以访问该层的特征图。 下面是一个示例,展示了如何使用hook函数UNet模型的特征图: ```python import torch import torch.nn as nn import torchvision.utils as vutils # 定义UNet模型 class UNet(nn.Module): def __init__(self): super(UNet, self).__init__() # 定义UNet的结构 def forward(self, x): # UNet的前向传播过程 return output # 创建UNet模型实例 model = UNet() # 注册hook函数的回调函数 def hook_fn(module, input, output): # 可特征图的代码 fmap_grid = vutils.make_grid(output, normalize=True, scale_each=True) writer.add_image('feature map', fmap_grid, global_step=322) # 找到要可特征图的层 target_layer = model.conv1 # 注册hook函数 hook_handle = target_layer.register_forward_hook(hook_fn) # 执行前向传播 output = model(input) # 移除hook函数 hook_handle.remove() ``` 在上面的代码中,您需要替换`UNet`类中的代码以定义您自己的UNet模型。然后,选择要可特征图的目标层,并将其传递给`register_forward_hook`函数以注册hook函数。 在hook函数中,您可以执行特征图的可操作,并使用TensorBoard将其添加到图像中。确保根据您的设置正确导入`torch`、`torch.nn`和`torchvision.utils`模块,并将`writer`替换为您用于记录TensorBoard事件的实际写入器。 请注意,在执行完前向传播后,不要忘记使用`remove()`方法移除hook函数,以免在之后的前向传播中再次调用hook函数

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值