pytorch中间层输出方法

引言

 本文想要解决的是pytorch中间层的输出问题,有时我们训练神经网络时会设定回归或者分类作为目标,但在测试阶段实际需要的只是用神经网络提取输入的表征,因此需要获得网络的中间层输出。总结起来有两种方法:

  • 一是在前向传播过程中通过"钩子"截取;
  • 二是将中间层输出也视作输出,在最后获取

一、钩子截流

 这种方式是在前向传播进行中还没得到最终输出时,将所需要的中间层输出从前向数据流中提取出来,利用到了pytorch中的register_hook()函数。这一函数可以为模型中的某个module设置一个回调函数,形如:
hook(module, input, output) -> None or modified output
函数的输入值为module的名字、module的输入和输出。通过前置定义一个数组,在hook()函数中将对应module的输入或输出加入该数组以实现中间层提取。给出代码如下:

import torch
from torch import  nn

class test_model(nn.Module):
    def __init__(self):
        super(test_model, self).__init__()

        self.conv_16 = nn.Sequential(
            nn.Conv2d(1,16,(3,3),(1,1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2))
        )
        self.conv_32 = nn.Sequential(
            nn.Conv2d(16,32,(3,3),(1,1)),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )

        self.linear_1 = nn.Sequential(
            nn.Linear(32,64),
            nn.ReLU()
        )
        self.linear_class = nn.Sequential(
            nn.Linear(64,5),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.conv_16(x)
        x = self.conv_32(x)
        x = x.view(x.shape[0],x.shape[1])
        x = self.linear_1(x)
        return self.linear_class(x)

features = []

def hook(module, input, output):
    features.append(input)
    return None

net = test_model()

# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():
    print(name)
#  设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))
net(a)
print(features)

实际过程中建议先打印所有层的名字以做到精确提取。

附:钩子函数

 值得注意的是这个函数的用途并不止于提取中间层的输出,它也可以用于对module的输出值进行修改,查看该函数的源码注释

r"""Registers a forward hook on the module.

The hook will be called every time after :func:forward has
computed an output. It should have the following signature::

   hook(module, input, output) -> None or modified output

The input contains only the positional arguments given to the
module. Keyword arguments won’t be passed to the hooks and only to
the forward. The hook can modify the output. It can modify the
input inplace but it will not have effect on forward since this is
called after :func:forward is called.

我们可以分析得到:
 register_hook()函数在对应module前向传播产生输出后自动执行,回调函数的输入只包括了module的位置参数不包括关键字参数。回调函数可以通过return 修改过的输出值来对module的最终输出进行修改,同样在回调函数内部我们也可以对输入进行inplace修改,但并不会对module的输出值造成影响因为register_hook()是在对应module前向传播产生输出后之执行,输入值已经被计算过了。(这里本人对最后一句话的理解与参考中的不一样,但未经验证过。)

二、视作输出

 通过返回值提取中间层输出比较简单,同样有两种方法来实现:

  • 一是将中间层的返回值作为模型的属性,在初始化时定义好;
  • 二是在forward函数将中间层返回值一并输出;

代码如下:

import torch
from torch import  nn

class test_model(nn.Module):
    def __init__(self):
        super(test_model, self).__init__()

        self.conv_16 = nn.Sequential(
            nn.Conv2d(1,16,(3,3),(1,1)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2,2))
        )
        self.conv_32 = nn.Sequential(
            nn.Conv2d(16,32,(3,3),(1,1)),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )

        self.linear_1 = nn.Sequential(
            nn.Linear(32,64),
            nn.ReLU()
        )
        self.linear_class = nn.Sequential(
            nn.Linear(64,5),
            nn.ReLU()
        )

        self.feature=[]

    def forward(self, x):
        x = self.conv_16(x)
        x = self.conv_32(x)
        x = x.view(x.shape[0],x.shape[1])
        x = self.linear_1(x)
        self.feature.append(x.detach())
        feature = x.detach()
        return self.linear_class(x),feature

features = []

def hook(module, input, output):
    features.append(input)
    return None

net = test_model()

# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():
    print(name)
#  设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))

_,final_out=net(a)
hook_out=features
att_out=net.feature

对比可以发现三者输出是一致的。

参考

Pytorch获取中间层输出的几种方法
pytorch的hook机制之register_forward_hook

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值