Pytorch:提取网络中某些层的输出

目的

我们通常在构建网络时,会使用一些比较成熟的网络构建backbone,比如ResNet、MobieNet等等。但有些时候并不需要使用整个backbone,而只需要其中某些层的输出,但自己构建一边backbone又很麻烦。

本文主要介绍这种方法就可以很方便地从一个已经搭建好的网络中方便地提取到某些层的输出。

IntermediateLayerGetter方法

参考自torchvision的实现,代码与注释如下:

class IntermediateLayerGetter(nn.ModuleDict):
    """ get the output of certain layers """
    def __init__(self, model, return_layers):
    	# 判断传入的return_layers是否存在于model中
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
		
        orig_return_layers = return_layers
        return_layers = {k: v for k, v in return_layers.items()}	# 构造dict
        layers = OrderedDict()
        # 将要从model中获取信息的最后一层之前的模块全部复制下来
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers) # 将所需的网络层通过继承的方式保存下来
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        # 将所需的值以k,v的形式保存到out中
        for name, module in self.named_children():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

使用

使用起来非常方便,首先确定好你要返回的信息在网络中的那个module,然后构造字典,k为backbone中的module名,v为返回out中的k值。示例如下:

import torchvision
    
model = torchvision.models.resnet18()
return_layers = {'layer1':'feature_1', 'layer2':'feature_2'}
backbone = IntermediateLayerGetter(model, return_layers)

backbone.eval()
x = torch.randn(1,3,224,224)
out = backbone(x)
print(out['feature_1'].shape, out['feature_2'].shape)

输出:

torch.Size([1, 64, 56, 56]) torch.Size([1, 128, 28, 28])
  • 18
    点赞
  • 78
    收藏
    觉得还不错? 一键收藏
  • 16
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 16
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值