torch的hook钩子函数使用

本文中,我们将介绍如何使用torch从一个torch模型中获取特定层的输出。

在Pyorch中,可以利用hook获取、改变网络中间某一层变量的值和梯度,从而便捷地分析网络,而不用专门改变网络结构,十分好用,这里只介绍如何用hook获取中间层变量值。

使用钩子来获取特定层的输出,需要执行以下步骤:
1、创建一个模型实例。
2、定义一个回调函数来获取我们感兴趣的层输出。
3、运行模型并获取特定层的输出。

第一步、创建一个模型示例

import torch
import torch.nn as nn

#initial the network
class mynet(nn.Module):
    def __init__(self):
        super(mynet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1, 1),
            nn.ReLU(),
            nn.AvgPool2d(2, 2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(32 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        self.out = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        output = self.out(x)
        return output
        
net = mynet()

第二步、定义一个回调函数

# 定义两个空列表,保存输入和输出的特征图
feature_in_hook = []
feature_out_hook = []

# # 定义一个回调函数,用于获取我们感兴趣的层
def hook(module, fea_in, fea_out):
    feature_in_hook.append(fea_in)
    feature_out_hook.append(fea_out)

第三步、运行模型并获取特定层的输出


net = mynet()	# 自己的模型

layer_name = ''  # 填写自己感兴趣的层名称
for (name, module) in net.named_modules():
    print(name)		# 获取每个网络层的名称
	if name==layer_name:
		# 注册钩子到我们感兴趣的层
       module.register_forward_hook(hook=hook)

# 创建随机输入
input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
out = model(input)

print(feature_in_hook)
print(feature_out_hook)
print(feature_out_hook[0].shape)

获取指定的网络层特征。

完整代码:

import torch
import torch.nn as nn


# initial the network
class mynet(nn.Module):
    def __init__(self):
        super(mynet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1, 1),
            nn.ReLU(),
            nn.AvgPool2d(2, 2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(32 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        self.out = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        output = self.out(x)
        return output


net = mynet()

# 定义两个空列表,保存输入和输出的特征图
feature_in_hook = []
feature_out_hook = []


# # 定义一个回调函数,用于获取我们感兴趣的层
def hook(module, fea_in, fea_out):
    feature_in_hook.append(fea_in)
    feature_out_hook.append(fea_out)


layer_name = 'conv2.0'  # 填写自己感兴趣的层名称
for (name, module) in net.named_modules():
    print(name)  # 获取每个网络层的名称
    if name == layer_name:
        # 注册钩子到我们感兴趣的层
        module.register_forward_hook(hook=hook)

        # 创建随机输入
        input = torch.randn(1, 1, 28, 28)
        out = net(input)

        print(feature_in_hook)
        print(feature_out_hook)
        print(feature_out_hook[0].shape)


  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ghx3110

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值