Hook函数与CAM算法
Hook Function
Hook函数机制:不改变主体,实现额外功能,像一个挂件,挂钩,hook
1.torch.Tensor.register_hook(hook)
功能 注册一个反向传播hook函数
Hook函数仅一个输入参数,为张量的梯度
2.torch.nn.Module.register_forward_hood
功能 注册module 的前向传播hook 函数
参数
module 当前网络层
input 当前网络层输入数据
output 当前网络层输出数据
3.torch.nn.Module.register_forward_pre_hook
功能 注册module 前向传播前的hook 函数
参数
module 当前网络层
input 当前网络层输入数据
4.torch.nn.Module.register_backward_hook
功能 注册module 反向传播的hook 函数
参数
module 当前网络层
grad_input 当前网络层输入梯度数据
grad_output 当前网络层输出梯度数据
CAM
类激活图 class activation map
对网络的最后一个特征图进行加权求和 得到注意力机制
Grad-CAM
CAM改进版,利用梯度作为特征图权重
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1) # 设置随机种子
# ----------------------------------- 1 tensor hook 1 -----------------------------------
flag = 0
# flag = 1
if flag:
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
# 构建list 存储张量的梯度
a_grad = list()
# 把当前