一学就会 | PyTorch中Hook函数的使用

如果觉得本篇文章对您的学习起到帮助作用,请 点赞 + 关注 + 评论 ,留下您的足迹💪💪💪

Hook函数机制是不改变函数主体,实现额外功能,像一个挂件,挂钩。正是因为PyTorch计算图动态图的机制,所以才会有Hook函数。在动态图机制的运算,当运算结束后,一些中间变量就会被释放掉,例如,特征图,非leaf节点的梯度。但是有时候,我们需要这些中间变量,所以就出现了Hook函数。我们可以使用Hook函数获取这些中间变量。

本文相关推荐阅读:

一学就会 | PyTorch入门看这篇就够了

一学就会 | 基于PyTorch的TensorBoard可视化

一学就会 | LeNet在CIFAR10数据集上的应用

Hook函数

PyTorch提供四种Hook函数:
1、torch.Tensor.register_hook(hook)
2、torch.nn.Module.register_forward_hook
3、torch.nn.Module.register_forward_pre_hook
4、torch.nn.Module.register_backward_hook

1、torch.Tensor.register_hook

功能:注册一个反向传播hook函数,Hook函数仅一个输入参数,为张量的梯度。Hook不应修改其参数梯度值,但可以选择返回一个新的梯度,该梯度将代替grad使用。

hook(grad) -> Tensor or None

结合代码进行讲解:

import torch

# x,y 为leaf节点,也就是说,在计算的时候,PyTorch只会保留此节点的梯度值
x = torch.tensor([3.], requires_grad=True)
y = torch.tensor([5.], requires_grad=True)

# a,b均为中间值,在计算梯度时,此部分会被释放掉
a = x + y
b = x * y

c = a * b

# 新建列表,用于存储Hook函数保存的中间梯度值
a_grad = []
def hook_grad(grad):
    a_grad.append(grad)

# register_hook的参数为一个函数
handle = a.register_hook(hook_grad)
c.backward()

# 只有leaf节点才会有梯度值
print('gradient:',x.grad, y.grad, a.grad, b.grad, c.grad)
# Hook函数保留下来的中间节点a的梯度
print('a_grad:', a_grad[0])
# 移除Hook函数
handle.remove()

Out:

gradient: tensor([55.]) tensor([39.]) None None None
a_grad: tensor([15.])

2、torch.nn.Module.register_forward_hook

功能:注册module的前向传播Hook函数

参数:

  • module:当前网络层
  • input:当前网络层输入数据
  • output:当前网络层输出数据

结合代码进行讲解:

import torch
import torch.nn as nn

# 构建网网络,一个卷积层一个池化层
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x
# 初始化网络
net = Net()
# detach将张量分离
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.detach().zero_()

# 构建两个列表用于保存信息
fmap_block = []
input_block = []

def forward_hook(module, data_input, data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)

# 注册Hook
net.conv1.register_forward_hook(forward_hook)

# 输入数据
fake_img = torch.ones((1, 1, 4, 4))
output = net(fake_img)

# 观察结果

# 卷积神经网络输出维度和结果
print("output share:{}\noutput value:{}\n".format(output.size(),output))

# 卷积神经网络Hook函数返回的结果
print("feature map share:{}\noutput value:{}\n".format(fmap_block[0].shape,fmap_block[0]))

# 输入的信息
print("input share:{}\ninput value:{}\n".format(input_block[0][0].size(),input_block[0][0]))

3、torch.nn.Module.register_forward_pre_hook

功能:注册module前向传播前的hook函数。

参数:

  • module:当前网络层
  • input:当前网络层输入数据

4、torch.nn.Module.register_backward_hook

功能:注册module反向传播的hook函数。

参数:

  • module:当前网络层
  • grad_input:当前网络层输入梯度数据
  • grad_output:当前网络层输出梯度数据

Hook函数进行特征提取

我们这里用一学就会 | LeNet在CIFAR10数据集上的应用训练好的模型来做实验。

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
        
def main():
   
    img_path = './car.jpg'

    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    img = Image.open(img_path)

    
    img = transform(img)
    img.unsqueeze_(dim=0)

    # 实例化
    net = LeNet()
    PATH = 'cifar_net_10.pth'
    # 将训练好的参数导入
    net.load_state_dict(torch.load(PATH))

    fmap_block = []
    input_block = []

    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)

    # 注册Hook
    net.conv1.register_forward_hook(forward_hook)
    net.conv2.register_forward_hook(forward_hook)

    with torch.no_grad():
        outputs = net(img)

         print("conv1 feature map share:{}".format(fmap_block[0].shape))

        print("conv2 feature map share:{}".format(fmap_block[1].shape))


if __name__ == '__main__':
    main()

Out:

conv1 feature map share:torch.Size([1, 6, 28, 28])
conv2 feature map share:torch.Size([1, 16, 10, 10])

如果您觉得这篇文章对你有帮助,记得 点赞 + 关注 + 评论 三连,您只需动一动手指,将会鼓励我创作出更好的文章,快留下你的足迹吧💪💪💪

  • 17
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值