前言
针对结构中定义了多个nn.sequential的网络模型,无法直接获取其内部某一中间层的输出,本文将给出两个方法进行解决。
方法
1 逐层进行forward
创建自定义函数,实现按照执行顺序逐层前向执行网络模型。
-----将模型输入以及模型作为参数传入函数,返回目标结果
def extract_res(inp, model):
for index, module in enumerate(model.modules()): # 按执行顺序遍历网络各层操作
if index in [0, 1, ...]: # 去除非操作层
continue
inp = module(inp) # 逐层前向执行,得到结果
if index == 3: # 判断是否为目标层 (示例为索引为3的操作)
return inp
tip: 利用.modules()在进行遍历操作时,其顺序为:
【总网络结构–>各部分–>各部分内部】
==》可见index = 0,1对应为非操作层,需要避免其执行forward。
故在使用此方法时,需要注意摒弃非操作层,跳过执行。此外,须推理出目标输出层对应索引号,才能实现精准获取。
2 使用hook函数
(1)定义保存hook内容的对象类
class SaveOutput:
def __init__(self):
self.outputs = []
def __call__(self, module, module_in, module_out):
self.outputs.append(module_out)
def clear(self):
self.outputs = []
(2) 为卷积层注册hook
hook_handles = []
save_output = SaveOutput()
for layer in model.modules(): # 按执行顺序遍历网络各层操作
if isinstance(layer, nn.Linear): # 按操作指令进行判别
handle = layer.register_forward_hook(save_output)
hook_handles.append(handle)
代码中示例即为寻找所有为执行nn.Linear()的操作层
(3) 对输入x进行预测(过程中每计算一个输出将自动调用hook函数)
out = model(x)
(4)取出通过目标层的输出
data = save_output.outputs[2] # 2为目标操作层输出在最终结果列表中的索引
tip: 网络包含几个操作同名层,save_output.outputs的size就为多少,取出对应位置的输出即可
------tbc-------
有用可以点个大拇指哦 🤭