FCN网络中IntermediateLayerGetter()类解析

看很多人都说这个类的作用是取出return_layers中所指定的层及其输出,这样其实是比较片面的,在FCN中调用这个class的时候这样解释不通,想了半天才想明白。本篇blog也是以FCN中源码来讲的。以下先放一下FCN的网络结构:

     首先来看一下在哪里调用了IntermediateLayerGetter()这个类:

def fcn_resnet101(aux, num_classes=21, pretrain_backbone=False):
    # 'resnet101_imagenet': 'https://download.pytorch.org/models/resnet101-63fe2227.pth'
    # 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth'
    backbone = resnet101(replace_stride_with_dilation=[False, True, True])

    if pretrain_backbone:
        # 载入resnet101 backbone预训练权重
        backbone.load_state_dict(torch.load("resnet101.pth", map_location='cpu'))

    out_inplanes = 2048
    aux_inplanes = 1024

    return_layers = {'layer4': 'out'}
    if aux:
        return_layers['layer3'] = 'aux'
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)#这里用到了

    aux_classifier = None
    # why using aux: https://github.com/pytorch/vision/issues/4292
    if aux:
        aux_classifier = FCNHead(aux_inplanes, num_classes)

    classifier = FCNHead(out_inplanes, num_classes)

    model = FCN(backbone, classifier, aux_classifier)

    return model

        这里再先贴一下IntermediateLayerGetter()类的源码,以便查看:

class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
    """
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}

        # 重新构建backbone,将没有使用到的模块全部删掉
        layers = OrderedDict()  # 创建有序字典
        for name, module in model.named_children():  # 遍历子模块的名称和模块数据
            layers[name] = module  # 存储在有序字典中
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

首先我们可以看到,调用IntermediateLayerGetter()的地方其实是想重构backbone,return_layers传入的是{layer4:out}或者{'layer3':'aux','layer4':'out'},在这个类的定义源码中最难理解的就是del return_layers[name]这句。

return_layers其实是一个提示器的作用,因为在调用IntermediateLayerGetter(),传入的是resnet101(即backbone参数)和return_layers,在类定义源码里我做了注释“重新构建backbone,将没有使用到的模块全部删掉”这句,首先遍历各子模块的名称和模块数据,并有序存储在有序字典layers中,如果name在我们传入的return_layers中,那么就删除return_layers中该键值对(这里是重点!),为什么是删除return_layers中的键值对?因为return_layers是提示作用!比如return_layers中是{'layer3':'aux','layer4':'out'},我们先把layer3存储到了有序字典layers中,然后发现layer3存在于return_layers中,删掉该键值对后return_layers={‘layer4’:‘out’},return_layers不为空,继续for循环;然后遍历到了layer4,把layer4存储到了有序字典layers中,然后发现layer4存在于return_layers中,删掉该键值对后return_layers为空,break掉for循环,那么此时你有没有发现有序字典layers存储的就是resnet101中layer4及其之前的所有层!所以return_layers就是这样一个提示作用!即刚开始建的空的有序字典只存储resnet101中layer4及其之前的所有层,不会再添加layer4后面的层了!

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Taylor不想被展开

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值