学习笔记|Pytorch使用教程22(hook函数与CAM可视化)

39 篇文章 24 订阅

学习笔记|Pytorch使用教程22

本学习笔记主要摘自“深度之眼”,做一个总结,方便查阅。
使用Pytorch版本为1.2

  • Hook函数概念
  • Hook函数与特征图提取
  • CAM (class activation map,类激活图)

一.Hook函数概念

Hook函数机制:不改变主体,实现额外功能,像一个挂件,挂钩,hook

1.torch.Tensor.register_ hook(hook)
功能:注册一个反向传播hook函数
在这里插入图片描述
Hook函数仅一个输入参数,为张量的梯度
在这里插入图片描述
计算图与梯度求导:
y = ( x + w ) ∗ ( w + 1 ) a = x + w b = w + 1 y = a ∗ b ∂ y ∂ w = ∂ y ∂ a ∂ a ∂ w + ∂ y ∂ b ∂ b ∂ w = b + 1 + a + 1 = ( w + a ) + ( x + w ) = 2 + w + x + 1 = 2 + 1 + 2 + 1 = 5 \begin{aligned} y &=(x+w) *(w+1) \\ a &=x+w \quad b=w+1 \\ y &=a * b \\ \frac{\partial y}{\partial w} &=\frac{\partial y}{\partial a} \frac{\partial a}{\partial w}+\frac{\partial y}{\partial b} \frac{\partial b}{\partial w} \\ &=b+1+a+1 \\ &=(w+a)+(x+w) \\ &=2+w+x+1 \\ &=2+1+2+1=5 \end{aligned} yaywy=(x+w)(w+1)=x+wb=w+1=ab=aywa+bywb=b+1+a+1=(w+a)+(x+w)=2+w+x+1=2+1+2+1=5
在反向传播结束后,非叶子节点a和b的梯度会被释放掉。现使用hook函数捕获其梯度。
测试代码:

import torch
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # 设置随机种子


# ----------------------------------- 1 tensor hook 1 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()

    def grad_hook(grad):
        a_grad.append(grad)

    handle = a.register_hook(grad_hook)

    y.backward()

    # 查看梯度
    print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
    print("a_grad[0]: ", a_grad[0])
    handle.remove()

输出:

gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]:  tensor([2.])

tensor hook尝试修改叶子节点梯度

# ----------------------------------- 2 tensor hook 2 -----------------------------------
# flag = 0
flag = 1
if flag:

    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()

    def grad_hook(grad):
        grad *= 2
        #return grad*3

    handle = w.register_hook(grad_hook)

    y.backward()

    # 查看梯度
    print("w.grad: ", w.grad)
    handle.remove()

输出(需要保留上段代码结果):

gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]:  tensor([2.])
w.grad:  tensor([10.])

如果加上 return grad*3,会覆盖原始张量梯度,输出:

gradient: tensor([5.]) tensor([2.]) None None None
a_grad[0]:  tensor([2.])
w.grad:  tensor([30.])

2.torch.nn.Module.register_forward _hook
功能:注册module的前向传播hook函数
在这里插入图片描述
参数:

  • module:当前网络层
  • input :当前网络层输入数据
  • output:当前网络层输出数据

在这里插入图片描述
测试代码:

# ----------------------------------- 3 Module.register_forward_hook and pre hook -----------------------------------
# flag = 0
flag = 1
if flag:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # 初始化网络
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # 注册hook
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    #net.conv1.register_forward_pre_hook(forward_pre_hook)
    #net.conv1.register_backward_hook(backward_hook)

    # inference
    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
    output = net(fake_img)

    """
    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(target, output)
    loss.backward()
    """
    # 观察
    print("output shape: {}\noutput value: {}\n".format(output.shape, output))
    print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
    print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

输出:

output shape: torch.Size([1, 2, 1, 1])
output value: tensor([[[[ 9.]],

         [[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward>)

feature maps shape: torch.Size([1, 2, 2, 2])
output value: tensor([[[[ 9.,  9.],
          [ 9.,  9.]],

         [[18., 18.],
          [18., 18.]]]], grad_fn=<ThnnConv2DBackward>)

input shape: torch.Size([1, 1, 4, 4])
input value: (tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)

查看hook函数运行机制,在该处设置断点:output = net(fake_img),并进入(step into)
在这里插入图片描述
3.torch.nn.Module.register _forward_ pre_ hook
功能:注册module前向传播的hook函数
在这里插入图片描述
参数:

  • module:当前网络层
  • input :当前网络层输入数据

4.torch.nn.Module.register_backward_hook
功能:注册module反向传播的hook函数
在这里插入图片描述
参数:

  • module:当前网络层
  • grad_input :当前网络层输入梯度数据
  • grad_output :当前网络层输出梯度数据

完整代码如上述所示,现在取消注释。

# ----------------------------------- 3 Module.register_forward_hook and pre hook -----------------------------------
# flag = 0
flag = 1
if flag:

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))

    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))

    # 初始化网络
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # 注册hook
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    net.conv1.register_forward_pre_hook(forward_pre_hook)
    net.conv1.register_backward_hook(backward_hook)

    # inference
    fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W
    output = net(fake_img)


    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(target, output)
    loss.backward()

    # 观察
    # print("output shape: {}\noutput value: {}\n".format(output.shape, output))
    # print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
    # print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

输出:

forward_pre_hook input:(tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]]),)
backward hook input:(None, tensor([[[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]],


        [[[0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000]]]]), tensor([0.5000, 0.5000]))
backward hook output:(tensor([[[[0.5000, 0.0000],
          [0.0000, 0.0000]],

         [[0.5000, 0.0000],
          [0.0000, 0.0000]]]]),)

二.Hook函数与特征图提取

测试代码:

import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from tools.common_tools import set_seed
import torchvision.models as models

set_seed(1)  # 设置随机种子

# ----------------------------------- feature map visualization -----------------------------------
# flag = 0
flag = 1
if flag:
    writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

    # 数据
    path_img = "./lena.png"     # your path to image
    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]

    norm_transform = transforms.Normalize(normMean, normStd)
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        norm_transform
    ])

    img_pil = Image.open(path_img).convert('RGB')
    if img_transforms is not None:
        img_tensor = img_transforms(img_pil)
    img_tensor.unsqueeze_(0)    # chw --> bchw

    # 模型
    alexnet = models.alexnet(pretrained=True)

    # 注册hook
    fmap_dict = dict()
    for name, sub_module in alexnet.named_modules():

        if isinstance(sub_module, nn.Conv2d):
            key_name = str(sub_module.weight.shape)
            fmap_dict.setdefault(key_name, list())

            n1, n2 = name.split(".")

            def hook_func(m, i, o):
                key_name = str(m.weight.shape)
                fmap_dict[key_name].append(o)

            alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

    # forward
    output = alexnet(img_tensor)

    # add image
    for layer_name, fmap_list in fmap_dict.items():
        fmap = fmap_list[0]
        fmap.transpose_(0, 1)

        nrow = int(np.sqrt(fmap.shape[0]))
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
        writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

使用tensorboard,在当前路径下输入:tensorboard --logdir=./runs
在浏览器中进入:http://localhost:6006/
在这里插入图片描述

三.CAM (class activation map,类激活图)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值