钩子截流:
这种方式是在前向传播进行中还没得到最终输出时,将所需要的中间层输出从前向数据流中提取出来,利用到了pytorch中的register_hook()函数。这一函数可以为模型中的某个module设置一个回调函数,形如:
hook(module, input, output) -> None or modified output
函数的输入值为module的名字、module的输入和输出。通过前置定义一个数组,在hook()函数中将对应module的输入或输出加入该数组以实现中间层提取。实际过程中建议先打印所有层的名字以做到精确提取。给出代码如下:
网络:
class net1(nn.Module):
def __init__(self):
super(net1, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, bias=False),
nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=0, bias=False),
)
self.conv2 = nn.Sequential(
nn.C