引言
本文想要解决的是pytorch中间层的输出问题,有时我们训练神经网络时会设定回归或者分类作为目标,但在测试阶段实际需要的只是用神经网络提取输入的表征,因此需要获得网络的中间层输出。总结起来有两种方法:
- 一是在前向传播过程中通过"钩子"截取;
- 二是将中间层输出也视作输出,在最后获取
一、钩子截流
这种方式是在前向传播进行中还没得到最终输出时,将所需要的中间层输出从前向数据流中提取出来,利用到了pytorch中的register_hook()函数。这一函数可以为模型中的某个module设置一个回调函数,形如:
hook(module, input, output) -> None or modified output
函数的输入值为module的名字、module的输入和输出。通过前置定义一个数组,在hook()函数中将对应module的输入或输出加入该数组以实现中间层提取。给出代码如下:
import torch
from torch import nn
class test_model(nn.Module):
def __init__(self):
super(test_model, self).__init__()
self.conv_16 = nn.Sequential(
nn.Conv2d(1,16,(3,3),(1,1)),
nn.ReLU(),
nn.MaxPool2d(kernel_size=(2,2))
)
self.conv_32 = nn.Sequential(
nn.Conv2d(16,32,(3,3),(1,1)),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1)
)
self.linear_1 = nn.Sequential(
nn.Linear(32,64),
nn.ReLU()
)
self.linear_class = nn.Sequential(
nn.Linear(64,5),
nn.ReLU()
)
def forward(self, x):
x = self.conv_16(x)
x = self.conv_32(x)
x = x.view(x.shape[0],x.shape[1])
x = self.linear_1(x)
return self.linear_class(x)
features = []
def hook(module, input, output):
features.append(input)
return None
net = test_model()
# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():
print(name)
# 设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))
net(a)
print(features)
实际过程中建议先打印所有层的名字以做到精确提取。
附:钩子函数
值得注意的是这个函数的用途并不止于提取中间层的输出,它也可以用于对module的输出值进行修改,查看该函数的源码注释
r"""Registers a forward hook on the module.
The hook will be called every time after :func:
forward
has
computed an output. It should have the following signature::hook(module, input, output) -> None or modified output
The input contains only the positional arguments given to the
module. Keyword arguments won’t be passed to the hooks and only to
theforward
. The hook can modify the output. It can modify the
input inplace but it will not have effect on forward since this is
called after :func:forward
is called.
我们可以分析得到:
register_hook()函数在对应module前向传播产生输出后自动执行,回调函数的输入只包括了module的位置参数不包括关键字参数。回调函数可以通过return 修改过的输出值来对module的最终输出进行修改,同样在回调函数内部我们也可以对输入进行inplace修改,但并不会对module的输出值造成影响因为register_hook()是在对应module前向传播产生输出后之执行,输入值已经被计算过了。(这里本人对最后一句话的理解与参考中的不一样,但未经验证过。)
二、视作输出
通过返回值提取中间层输出比较简单,同样有两种方法来实现:
- 一是将中间层的返回值作为模型的属性,在初始化时定义好;
- 二是在forward函数将中间层返回值一并输出;
代码如下:
import torch
from torch import nn
class test_model(nn.Module):
def __init__(self):
super(test_model, self).__init__()
self.conv_16 = nn.Sequential(
nn.Conv2d(1,16,(3,3),(1,1)),
nn.ReLU(),
nn.MaxPool2d(kernel_size=(2,2))
)
self.conv_32 = nn.Sequential(
nn.Conv2d(16,32,(3,3),(1,1)),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1)
)
self.linear_1 = nn.Sequential(
nn.Linear(32,64),
nn.ReLU()
)
self.linear_class = nn.Sequential(
nn.Linear(64,5),
nn.ReLU()
)
self.feature=[]
def forward(self, x):
x = self.conv_16(x)
x = self.conv_32(x)
x = x.view(x.shape[0],x.shape[1])
x = self.linear_1(x)
self.feature.append(x.detach())
feature = x.detach()
return self.linear_class(x),feature
features = []
def hook(module, input, output):
features.append(input)
return None
net = test_model()
# 确定想要提取出的中间层名字
for (name, module) in net.named_modules():
print(name)
# 设置钩子
net.linear_class[0].register_forward_hook(hook)
a = torch.randn((3,1,28,28))
_,final_out=net(a)
hook_out=features
att_out=net.feature
对比可以发现三者输出是一致的。