获取网络中间几层的结果IntermediateLayerGetter()及源码分析

获取网络中间几层的结果IntermediateLayerGetter()及源码分析


思想:先创建一个model ,然后把它传入IntermediateLayerGetter中,并传入一个字典,传入字典的key是model的直接的层,传入字典的value是返回字典中的key,返回字典的value对应的是model运行的中间结果。

一个小技巧是,传入的字典期望是str–str,如果传入str–int,那么使用的方式和字典是一样的。

注意:因为 model.named_children() 只能找到直接下一层的名字,所以传入字典的key只能写直接下一层的名字。

在这里插入图片描述

官方帮助:在这里插入图片描述

class IntermediateLayerGetter(nn.ModuleDict):
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers  # 接收传入的字典
        return_layers = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:  # return_layers相当于一个缓存,如果缓存中的项都空了,说明只需要到这里就可以结束查找了。上面找到的layers字典已经包含了想要的所有中间层了。
                break

        super(IntermediateLayerGetter, self).__init__(layers)  # ModuleDict父类的初始化方式
        self.return_layers = orig_return_layers  # 这里有一个原始的传入字典return_layers的副本,在返回的时候时使用
        
	def forward(self, x):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)  # 运算时,还是完整的遍历了一遍net。所以只需要调用一次就行了。而不是额外调用一次net(x)
            if name in self.return_layers:  # 如果name在传入字典的key中
                out_name = self.return_layers[name]  #  返回字典的key=传入字典的value
                out[out_name] = x  # 返回字典的值,是当前层值
        return out  # 把得到的中间结果返回
    # 所以关键点是,传入字典的值,正好是  层的名字。  中间结果赋给 传入字典的值为  key对应的value 中。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值