pytorch: hook机制,取网络层的输入输出和梯度

前言

本篇记录在pytorch中,通过hook机制将网络某层的输入输出和梯度取出。

pytorch的hook机制

hook是pytorch中一个独特的机制,可以用于将变量的梯度、网络层的输入输出和梯度“钩出来”保存和修改。

使用hook流程:首先定义hook函数对梯度或者输入输出的操作,然后register到变量或者网络层,最后对网络推理或者变量反向传播激活hook函数生效,去除hook。

register_hook用于变量梯度操作

register_hook接收的hook函数只包含变量梯度这一个参数:

import torch

def hook_function(grad):
    grad += 1

a = torch.tensor([1, 1, 1], requires_grad=True, dtype=torch.float32)
b = a.mean()
h = a.register_hook(hook_function)
b.backward()

print(b.grad) # tensor([1.333, 1.333, 1.333])
h.remove()

尽量不要通过register_hook修改变量梯度。如果要取出变量梯度,可以定义全局变量然后把grad赋过去。

register_forward_hook用于网络层输入输出操作

register_forward_hook接收的hook函数包含网络层,层输入和层输出三个参数:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.linear1 = nn.Linear(3, 10)
        self.linear2 = nn.Linear(10, 1)
    def forward(self, x):
        return F.relu(self.linear2(F.relu(self.linear1(x))))

def hook_forward_function(module, module_input, module_output):
    print(module_input)
    print(module_output)

mlp = MLP()
h1 = mlp.linear2.register_forward_hook(hook_forward_function)
a = torch.tensor([1, 1, 1], requires_grad=True, dtype=torch.float32)
y = mlp(a)

h1.remove()

网络推理到hook注册层时,hook函数会被调用。

register_full_backward_hook用于网络层输入输出梯度操作

与forward hook类似,register_full_backward_hook也接收三参数hook函数,但分别是网络层,层输入梯度和层输出梯度。

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.linear1 = nn.Linear(3, 10)
        self.linear2 = nn.Linear(10, 1)
    def forward(self, x):
        return F.relu(self.linear2(F.relu(self.linear1(x))))

def hook_backward_function(module, module_input_grad, module_output_gard):
    print(module_input_grad)
    print(module_output_grad)

mlp = MLP()
h1 = mlp.register_full_backward_hook(hook_backward_function)
a = torch.tensor([1, 1, 1], requires_grad=True, dtype=torch.float32)
y = mlp(a)
y.backward()

h1.remove()

网络输出backward时,hook函数被调用。

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是一个使用PyTorchhook机制来获EfficientDate模型卷积层特征图的示例代码: ```python import torch from efficientnet_pytorch import EfficientNet # 加载EfficientDate模型 model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=10) # 定义一个列表,用于存储指定层的输出 outputs = [] # 定义hook函数,用于获指定层的输出 def hook(module, input, output): # 将输出保存到列表中 outputs.append(output) # 注册hook函数到指定层 target_layer = model._blocks[9]._depthwise_conv hook_handle = target_layer.register_forward_hook(hook) # 输入图像进行前向传播 inputs = torch.randn(1, 3, 224, 224) outputs = model(inputs) # 获指定层的输出作为特征图 feature_map = outputs[0] # 将特征图保存为图片 import matplotlib.pyplot as plt plt.imshow(feature_map.detach().numpy()[0, 0, :, :], cmap='gray') plt.show() # 移除hook函数 hook_handle.remove() ``` 在上述代码中,我们首先加载了EfficientDate模型,然后定义了一个列表`outputs`,用于存储hook函数获的指定层的输出。接着,我们定义了一个hook函数`hook`,用于将指定层的输出保存到`outputs`列表中。然后,我们通过`register_forward_hook`方法将`hook`函数注册到EfficientDate模型的第9个block的深度卷积层上,以获该层的输出。接下来,我们输入图像进行前向传播,模型会自动调用hook函数,将指定层的输出保存到`outputs`列表中。然后我们将特征图绘制成灰度图并显示。最后,我们从`outputs`列表中获指定层的输出作为特征图,并移除hook函数。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值