在网络中间插入模块的方法(AI学习笔记)

文章详细描述了一种嵌入网络结构,使用ResNet作为基础,包括非局部模块的插入策略以及如何在base_resnet.base.layer1和layer2中应用这些模块。重点介绍了ResNet的基本块结构和残差连接的作用。
摘要由CSDN通过智能技术生成

一、插入模块

有如下代码:

class embed_net(nn.Module):
    def __init__(self,  class_num, no_local= 'on', gm_pool = 'on', arch='resnet50'):
        super(embed_net, self).__init__()

        self.thermal_module = thermal_module(arch=arch)
        self.visible_module = visible_module(arch=arch)
        self.base_resnet = base_resnet(arch=arch)
        self.non_local = no_local
        if self.non_local =='on':
            layers=[3, 4, 6, 3]
            non_layers=[0,2,3,0]
            self.NL_1 = nn.ModuleList(
                [Non_local(256) for i in range(non_layers[0])])
            self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
            self.NL_2 = nn.ModuleList(
                [Non_local(512) for i in range(non_layers[1])])
            self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
            self.NL_3 = nn.ModuleList(
                [Non_local(1024) for i in range(non_layers[2])])
            self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
            self.NL_4 = nn.ModuleList(
                [Non_local(2048) for i in range(non_layers[3])])
            self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])


        pool_dim = 2048
        self.l2norm = Normalize(2)
        self.bottleneck = nn.BatchNorm1d(pool_dim)
        self.bottleneck.bias.requires_grad_(False)  # no shift

        self.classifier = nn.Linear(pool_dim, class_num, bias=False)

        self.bottleneck.apply(weights_init_kaiming)
        self.classifier.apply(weights_init_classifier)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.gm_pool = gm_pool

    def forward(self, x1, x2, modal=0):
        if modal == 0:
            x1 = self.visible_module(x1)
            x2 = self.thermal_module(x2)
            x = torch.cat((x1, x2), 0)
        elif modal == 1:
            x = self.visible_module(x1)
        elif modal == 2:
            x = self.thermal_module(x2)

        # shared block
        if self.non_local == 'on':
            # 在base_resnet.base.layer1后面插入模块
            NL1_counter = 0
            if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1]
            for i in range(len(self.base_resnet.base.layer1)):
                x = self.base_resnet.base.layer1[i](x)
                if i == self.NL_1_idx[NL1_counter]:
                    _, C, H, W = x.shape
                    x = self.NL_1[NL1_counter](x)
                    NL1_counter += 1
                    
            # Layer 2
            NL2_counter = 0
            if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1]
            for i in range(len(self.base_resnet.base.layer2)):
                x = self.base_resnet.base.layer2[i](x)
                if i == self.NL_2_idx[NL2_counter]:
                    _, C, H, W = x.shape
                    x = self.NL_2[NL2_counter](x)
                    NL2_counter += 1
            # Layer 3
            NL3_counter = 0
            if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1]
            for i in range(len(self.base_resnet.base.layer3)):
                x = self.base_resnet.base.layer3[i](x)
                if i == self.NL_3_idx[NL3_counter]:
                    _, C, H, W = x.shape
                    x = self.NL_3[NL3_counter](x)
                    NL3_counter += 1
            # Layer 4
            NL4_counter = 0
            if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1]
            for i in range(len(self.base_resnet.base.layer4)):
                x = self.base_resnet.base.layer4[i](x)
                if i == self.NL_4_idx[NL4_counter]:
                    _, C, H, W = x.shape
                    x = self.NL_4[NL4_counter](x)
                    NL4_counter += 1
        else:
            x = self.base_resnet(x)
        if self.gm_pool  == 'on':
            b, c, h, w = x.shape
            x = x.view(b, c, -1)
            p = 3.0
            x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p)
        else:
            x_pool = self.avgpool(x)
            x_pool = x_pool.view(x_pool.size(0), x_pool.size(1))

        feat = self.bottleneck(x_pool)

        if self.training:
            return x_pool, self.classifier(feat)
        else:
            return self.l2norm(x_pool), self.l2norm(feat)

        注意代码第51-58行,它在base_resnet.base.layer1后面加入了一个NL1模块,此时,对于 self.base_resnet.base.layer2,它的输入是 x,而这个 x 是由 self.base_resnet.base.layer1 的循环处理过的结果。也就是说,x 是经过 layer1 处理的特征张量。

        具体来说,x 最初是输入到整个 ResNet 网络的原始输入。在第一个 for 循环中,经过 self.base_resnet.base.layer1[i](x) 的处理,x 发生了变化,变成了 layer1 中第 i 个基本块的输出。

        如果在这个循环中,if i == self.NL_1_idx[NL1_counter]: 的条件满足,那么就会在 x 上插入 NL_1[NL1_counter] 中的 Non_local 模块,进一步修改了 x

        所以,对于 self.base_resnet.base.layer2 来说,它的输入是经过 self.base_resnet.base.layer1 处理过,并且在部分位置经过了 Non_local 模块的修改的特征张量 x

二、对self.base_resnet.base.layer1/2的解释

   self.base_resnet.base.layer1 是指向 ResNet 网络中的第一个层的引用。在 PyTorch 中,ResNet 网络被组织成一系列的层,其中每个层包含多个基本块(basic block)。

         具体而言,ResNet 网络的结构如下:

  • layer1: 第一个阶段,包含多个基本块。
  • layer2: 第二个阶段,也包含多个基本块。
  • layer3: 第三个阶段,同样包含多个基本块。
  • layer4: 第四个阶段,仍然包含多个基本块。

     每个基本块由两个卷积层组成,同时包含批量归一化和激活函数。这些基本块通过残差连接(Residual Connection)的方式连接在一起,以帮助解决梯度消失的问题,使得更深的网络可以更容易地训练。

         因此,self.base_resnet.base.layer1 是指向 ResNet 网络中的第一个阶段(stage)的引用。这个阶段通常包含多个基本块,每个基本块又包含若干个卷积层、归一化层和激活函数。

       以resnet18为例:

   self.base_resnet.base.layer1 的内部结构通常包含四个基本块(basic block)。每个基本块由两个卷积层组成,具体结构如下:

  1. Basic Block 1:

    • 3x3 卷积层
    • 批量归一化(Batch Normalization)
    • ReLU 激活函数
    • 3x3 卷积层
    • 批量归一化
    • 残差连接(Residual Connection):通过跳跃连接将输入直接添加到卷积输出上
    • ReLU 激活函数
  2. Basic Block 2:

    • 同上
  3. Basic Block 3:

    • 同上
  4. Basic Block 4:

    • 同上

      这是 ResNet-18 中第一个阶段(layer1)的基本块结构。整个 layer1 就是这四个基本块的堆叠。这些基本块的设计目的是通过残差连接来缓解梯度消失问题,使得神经网络更容易训练。

     而进一步的,self.base_resnet.base.layer1[2] 指的是 ResNet 网络中第一个阶段(layer1)的第三个基本块(basic block)。在 ResNet 的 PyTorch 实现中,这是一个通过索引选择基本块的方式。

      在 ResNet-18 中,layer1 由四个基本块组成,索引从 0 到 3。因此self.base_resnet.base.layer1[2] 就是 layer1 中的第三个基本块。

  • 11
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值