Pytorch获取中间层信息-hook函数

参考链接:https://www.cnblogs.com/hellcat/p/8512090.html
由于pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用hook函数。hook函数包括tensor的hook和nn.Module的hook,用法相似。hook函数在使用后应及时删除,以避免每次都运行钩子增加运行负载。hook函数主要用在获取某些中间结果的情景,如中间某一层的输出或某一层的梯度。这些结果本应写在forward函数中,但如果在forward函数中专门加上这些处理,可能会使处理逻辑比较复杂,这时候使用hook技术就更合适一些

Tensor对象

参考:https://pytorch.org/docs/stable/tensors.html
有如下的register_hook(hook)方法,为Tensor注册一个backward hook,用来获取变量的梯度。
hook必须遵循如下的格式:hook(grad) -> Tensor or None,其中grad为获取的梯度
具体的实例如下:

import torch

grad_list = []
def print_grad(grad):
    grad = grad * 2
    grad_list.append(grad)

x = torch.tensor([[1., -1.], [1., 1.]], requires_grad=True)
h = x.register_hook(print_grad)    # double the gradient
out = x.pow(2).sum()
out.backward()
print(grad_list)
'''
[tensor([[ 4., -4.],
        [ 4.,  4.]])]
'''
# 删除hook函数
h.remove()

Module对象

register_forward_hook(hook)register_backward_hook(hook)两种方法,分别对应前向传播和反向传播的hook函数。

register_forward_hook(hook)

在网络执行forward()之后,执行hook函数,需要具有如下的形式:

hook(module, input, output) -> None or modified output

hook可以修改input和output,但是不会影响forward的结果。最常用的场景是需要提取模型的某一层(不是最后一层)的输出特征,但又不希望修改其原有的模型定义文件,这时就可以利用forward_hook函数。

import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

features = []
def hook(module, input, output):
    features.append(output.clone().detach())


net = LeNet()
x = torch.randn(2, 3, 32, 32)
handle = net.conv2.register_forward_hook(hook)
y = net(x)

print(features[0].size())
handle.remove()

register_backward_hook(hook)

每一次module的inputs的梯度被计算后调用hook,hook必须具有如下的签名:

hook(module, grad_input, grad_output) -> Tensor or None

grad_inputgrad_output参数分别表示输入的梯度和输出的梯度,是不能修改的,但是可以通过return一个梯度元组tuple来替代grad_input
展示一个实例来解析grad_inputgrad_output参数:

import torch
import torch.nn as nn


def hook(module, grad_input, grad_output):
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_output)


x = torch.tensor([[1., 2., 10.]], requires_grad=True)
module = nn.Linear(3, 1)
handle = module.register_backward_hook(hook)
y = module(x)
y.backward()
print('module_weight: ', module.weight.grad)

handle.remove()

输出:

grad_input:  (tensor([1.]), tensor([[ 0.1236, -0.0232, -0.5687]]), tensor([[ 1.],
        [ 2.],
        [10.]]))
grad_output:  (tensor([[1.]]),)
module_weight:  tensor([[ 1.,  2., 10.]])

可以看出,grad_input元组包含(bias的梯度输入x的梯度权重weight的梯度),grad_output元组包含输出y的梯度。
可以在hook函数中通过return来修改grad_input

import torch
import torch.nn as nn


def hook(module, grad_input, grad_output):
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_output)
    return grad_input[0] * 0, grad_input[1] * 0, grad_input[2] * 0,


x = torch.tensor([[1., 2., 10.]], requires_grad=True)
module = nn.Linear(3, 1)
handle = module.register_backward_hook(hook)
y = module(x)
y.backward()
print('module_bias: ', module.bias.grad)
print('x: ', x.grad)
print('module_weight: ', module.weight.grad)

handle.remove()

输出:

grad_input:  (tensor([1.]), tensor([[ 0.1518,  0.0798, -0.3170]]), tensor([[ 1.],
        [ 2.],
        [10.]]))
grad_output:  (tensor([[1.]]),)
module_bias:  tensor([0.])
x:  tensor([[0., 0., -0.]])
module_weight:  tensor([[0., 0., 0.]])

对于没有参数的Module,比如nn.ReLU来说,grad_input元组包含(输入x的梯度),grad_output元组包含(输出y的梯度)。

def hook(module, grad_input, grad_output):
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_output)
    return (grad_input[0] / 4, )


x = torch.tensor([-1., 2., 10.], requires_grad=True)
module = nn.ReLU()
handle = module.register_backward_hook(hook)
y = module(x).sum()
z = y * y
z.backward()

print(x.grad)  # tensor([0., 6., 6.])
handle.remove()

输出:

grad_input:  (tensor([ 0., 24., 24.]),)
grad_output:  (tensor([24., 24., 24.]),)
tensor([0., 6., 6.])

y = R e L U ( x 1 ) + R e L U ( x 2 ) + R e L U ( x 3 ) y=ReLU(x_{1})+ReLU(x_{2})+ReLU(x_{3}) y=ReLU(x1)+ReLU(x2)+ReLU(x3)
z = y 2 z=y^{2} z=y2
grad_output是传到ReLU模块的输出值的梯度,即 ∂ z ∂ y = 2 y = 24 \frac{\partial z}{\partial y}=2y=24 yz=2y=24
grad_input是进入ReLU模块的输入值的梯度,由 ∂ y ∂ x 1 = 0 , ∂ y ∂ x 2 = 1 , ∂ y ∂ x 3 = 1 \frac{\partial y}{\partial x_{1}}=0,\frac{\partial y}{\partial x_{2}}=1,\frac{\partial y}{\partial x_{3}}=1 x1y=0,x2y=1,x3y=1,可得:
∂ z ∂ y ∂ y ∂ x 1 = 0 , ∂ z ∂ y ∂ y ∂ x 2 = 24 , ∂ z ∂ y ∂ y ∂ x 3 = 24 \frac{\partial z}{\partial y}\frac{\partial y}{\partial x_{1}}=0,\frac{\partial z}{\partial y}\frac{\partial y}{\partial x_{2}}=24,\frac{\partial z}{\partial y}\frac{\partial y}{\partial x_{3}}=24 yzx1y=0,yzx2y=24,yzx3y=24
在hook函数中可以对输入值 x x x的梯度进行缩放:
[ 0 , 24 , 24 ] / 4 = [ 0 , 6 , 6 ] [0,24,24]/4=[0,6,6] [0,24,24]/4=[0,6,6]

  • 34
    点赞
  • 72
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
PyTorchhook是一种功能强大的工具,可用于在神经网络的不同模块之间插入自定义的操作和代码。使用hook可以实现许多有用的功能,如获取模型某一层的输出,修改某一层的输入,以及计算某一层的梯度等。 要在半小时内学会PyTorchhook,我们可以按照以下步骤进行: 1. 导入必要的库和模块: ```python import torch import torch.nn as nn ``` 2. 创建一个简单的神经网络模型: ```python class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc = nn.Linear(10, 2) def forward(self, x): x = self.fc(x) return x model = Net() ``` 3. 编写一个hook函数,在该函数中定义自定义的操作: ```python def hook_fn(module, input, output): print(f"Module: {module}") print(f"Input: {input}") print(f"Output: {output}") # 注册hook到模型的某一层 hook_handle = model.fc.register_forward_hook(hook_fn) ``` 4. 准备输入数据并运行模型: ```python input_data = torch.randn(1, 10) output = model(input_data) ``` 5. 查看hook函数输出的结果: ``` Module: Linear(in_features=10, out_features=2, bias=True) Input: (tensor([[-0.1895, -1.3554, -0.2618, -0.5179, -1.6060, 0.8815, -1.7051, 2.4338, 0.9165, -1.2528]]),) Output: tensor([[-0.1895, -0.8663]], grad_fn=<AddmmBackward>) ``` 通过上述步骤,我们成功地在半小时内学会了如何使用PyTorchhook。这个例子展示了如何注册一个forward hook来查看某一层的输入和输出。你可以根据自己的需求编写不同的hook函数,并在不同的模块上注册hook来实现自定义的操作和分析。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值