pytorch 钩子 第一部分
- 有段时间一直在纠结怎么不改变原有网络结构, 直接得到网络中间层值
- 然后发现pytorch有这种方法_ register_forward_hook()
- 下面简单的介绍其用法
def get_feature(data, model, output):
avgpool_layer = model._modules.get('avgpool')
def fun(m, i, o): output.copy_(o.data)
h = avgpool_layer.register_forward_hook(fun)
feature = model(data)
h.remove()
return feature
调用:
feature_map = torch.zeros(data.size(0), 256, 6, 6)
get_feature(data, model, feature_map)