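I have seen many people say that this class simply extracts the layers named in return_layers together with their outputs. That explanation is rather one-sided: it does not hold up when you look at how the class is called in FCN, and it took me quite a while to figure out why. This post walks through the class using the FCN source code. The FCN network structure comes first:
Let's start with where the IntermediateLayerGetter() class is called: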
def fcn_resnet101(aux, num_classes=21, pretrain_backbone=False):
    # 'resnet101_imagenet': 'https://download.pytorch.org/models/resnet101-63fe2227.pth'
    # 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth'
    backbone = resnet101(replace_stride_with_dilation=[False, True, True])

    if pretrain_backbone:
        # Load the pretrained weights for the resnet101 backbone
        backbone.load_state_dict(torch.load("resnet101.pth", map_location='cpu'))

    out_inplanes = 2048
    aux_inplanes = 1024

    return_layers = {'layer4': 'out'}
    if aux:
        return_layers['layer3'] = 'aux'
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)  # the class is used here

    aux_classifier = None
    # why using aux: https://github.com/pytorch/vision/issues/4292
    if aux:
        aux_classifier = FCNHead(aux_inplanes, num_classes)

    classifier = FCNHead(out_inplanes, num_classes)

    model = FCN(backbone, classifier, aux_classifier)

    return model
Here is the source of the IntermediateLayerGetter() class as well, for easy reference:
from collections import OrderedDict
from typing import Dict

import torch.nn as nn
from torch import Tensor


class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model
    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.
    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.
    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
    """
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}

        # Rebuild the backbone and drop every module that is not used
        layers = OrderedDict()  # create an ordered dict
        for name, module in model.named_children():  # iterate over child names and modules
            layers[name] = module  # store the module in the ordered dict
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out
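Before looking at the constructor, it may help to see what forward hands back. The following is only a toy analogue in plain Python, not the real class: the "layers" are hypothetical functions on integers (toy_layers and toy_forward are names invented for this sketch), chosen so the data flow is easy to trace by hand.

```python
from collections import OrderedDict

# Hypothetical stand-ins for the stored child modules: each "layer" is a
# plain function on an int, so the running value is easy to follow.
toy_layers = OrderedDict([
    ("layer1", lambda x: x + 1),
    ("layer2", lambda x: x * 2),
    ("layer3", lambda x: x + 10),
    ("layer4", lambda x: x * 3),
])
return_layers = {"layer3": "aux", "layer4": "out"}


def toy_forward(x, layers, return_layers):
    """Mimic forward(): run every stored layer in order, record the named ones."""
    out = OrderedDict()
    for name, layer in layers.items():
        x = layer(x)  # every stored layer runs, whether or not it is recorded
        if name in return_layers:
            out[return_layers[name]] = x  # saved under the user-chosen name
    return out


features = toy_forward(1, toy_layers, return_layers)
print(dict(features))  # {'aux': 14, 'out': 42}
```

The point of the sketch is that the output dict is keyed by the values of return_layers ('aux', 'out'), not by the original layer names, and that layers which are not listed still run because later layers depend on them.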
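First, we can see that the place where IntermediateLayerGetter() is called is really rebuilding the backbone. The return_layers passed in is either {'layer4': 'out'} or {'layer3': 'aux', 'layer4': 'out'}. In the class source, the hardest line to understand is del return_layers[name].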
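return_layers really acts as a hint. The call to IntermediateLayerGetter() passes in resnet101 (the backbone argument) and return_layers. In the class source I added the comment "rebuild the backbone and drop every module that is not used" above the loop. The loop walks through the child modules by name, in order, and stores each one in the ordered dict layers. If name is in the return_layers we passed in, that key-value pair is deleted from return_layers (this is the key point!). Note that the deletion happens on the copy built by the dict comprehension; the original mapping is kept in orig_return_layers, and that is what forward uses later.

Why delete the key-value pair from return_layers? Because return_layers works as a hint! Take return_layers = {'layer3': 'aux', 'layer4': 'out'}. The children that come before layer3 are stored in layers without touching return_layers. Then layer3 is stored in layers, found in return_layers, and deleted, which leaves return_layers = {'layer4': 'out'}. return_layers is not empty, so the for loop continues. Next comes layer4: it is stored in layers, found in return_layers, and deleted. Now return_layers is empty, so the for loop breaks.

Have you noticed that at this point the ordered dict layers holds exactly layer4 of resnet101 and every layer before it? That is the hint role of return_layers! The ordered dict that started out empty only ever stores layer4 and the layers before it, and the layers after layer4 are never added.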
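The walkthrough above can be replayed without PyTorch. The sketch below is an illustration only: kept_children is a hypothetical helper that runs the same loop on names alone, and RESNET_CHILDREN assumes the child order of a torchvision-style ResNet (conv1, bn1, relu, maxpool, layer1 to layer4, avgpool, fc).

```python
from collections import OrderedDict

# Assumed direct children of a torchvision-style ResNet, in registration order.
RESNET_CHILDREN = ["conv1", "bn1", "relu", "maxpool",
                   "layer1", "layer2", "layer3", "layer4", "avgpool", "fc"]


def kept_children(child_names, return_layers):
    """Replay the truncation loop from __init__ on names only (hypothetical helper)."""
    remaining = dict(return_layers)  # working copy; the caller's dict is untouched
    layers = OrderedDict()
    for name in child_names:
        layers[name] = None  # the real code stores the module here
        if name in remaining:
            del remaining[name]  # tick this requested layer off the checklist
        if not remaining:
            break  # every requested layer has been seen; later children are dropped
    return list(layers)


with_aux = kept_children(RESNET_CHILDREN, {"layer3": "aux", "layer4": "out"})
only_layer3 = kept_children(RESNET_CHILDREN, {"layer3": "aux"})
print(with_aux)     # conv1 ... layer4; avgpool and fc are gone
print(only_layer3)  # conv1 ... layer3; layer4 is gone as well
```

Under these assumptions the first call keeps everything up to layer4 and the second stops at layer3, which shows that the cut-off point is set by the last requested layer in registration order, not by the order of the keys in return_layers.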