获取网络中间几层的结果IntermediateLayerGetter()及源码分析
思想:先创建一个model ,然后把它传入IntermediateLayerGetter中,并传入一个字典,传入字典的key是model的直接的层,传入字典的value是返回字典中的key,返回字典的value对应的是model运行的中间结果。
一个小技巧是,传入的字典期望是str–str,如果传入str–int,那么使用的方式和字典是一样的。
注意:因为 model.named_children() 只能找到直接下一层的名字,所以传入字典的key只能写直接下一层的名字。
官方帮助:
class IntermediateLayerGetter(nn.ModuleDict):
_version = 2
__annotations__ = {
"return_layers": Dict[str, str],
}
def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
if not set(return_layers).issubset([name for name, _ in model.named_children()]):
raise ValueError("return_layers are not present in model")
orig_return_layers = return_layers # 接收传入的字典
return_layers = {str(k): str(v) for k, v in return_layers.items()}
layers = OrderedDict()
for name, module in model.named_children():
layers[name] = module
if name in return_layers:
del return_layers[name]
if not return_layers: # return_layers相当于一个缓存,如果缓存中的项都空了,说明只需要到这里就可以结束查找了。上面找到的layers字典已经包含了想要的所有中间层了。
break
super(IntermediateLayerGetter, self).__init__(layers) # ModuleDict父类的初始化方式
self.return_layers = orig_return_layers # 这里有一个原始的传入字典return_layers的副本,在返回的时候时使用
def forward(self, x):
out = OrderedDict()
for name, module in self.items():
x = module(x) # 运算时,还是完整的遍历了一遍net。所以只需要调用一次就行了。而不是额外调用一次net(x)
if name in self.return_layers: # 如果name在传入字典的key中
out_name = self.return_layers[name] # 返回字典的key=传入字典的value
out[out_name] = x # 返回字典的值,是当前层值
return out # 把得到的中间结果返回
# 所以关键点是,传入字典的值,正好是 层的名字。 中间结果赋给 传入字典的值为 key对应的value 中。