pytorch默认只保存最后一层的输出,中间层输出默认不保存,要提取中间层网络输出值,需要使用回调函数register_forward_hook(),通过传入处理函数,便可以提取和保存特点网络层的输出值。
class ActivationData():
#网络输出值
outputs = None
def __init__(self,layer):
#在模型的layer_num层上注册回调函数,并传入处理函数hook_fn
self.hook = layer.register_forward_hook(self.hook_fn)
def hook_fn(self,module,input,output):
self.outputs = output.cpu()
def remove(self):
#由回调句柄调用,用于将回调函数从网络层删除
self.hook.remove()
#获取第二个卷积层
conv_out = ActivationOutputData(model.conv2)
#传入图片
o = model(img)
#移除回调函数
conv_out.remove()
#输出图片
for i in range(16):
ax.imshow(conv_out.outputs[0][i].detach().numpy())
资料来源:《pytorch深度学习实战 从新手小白到数据科学家》