Residual Attention Network: how the 56 and 92 layer counts are computed

Paper: Residual Attention Network for Image Classification
Paper link: https://arxiv.org/abs/1704.06904

There are two configurations:

def attention56():
    return Attention([1, 1, 1])

def attention92():
    return Attention([1, 2, 3])

The architecture from the original paper:
[Figure: architecture table of Attention-56 and Attention-92 from the paper]

So how are the numbers 56 and 92 computed?

The source code:

"""residual attention network in pytorch



[1] Fei Wang, Mengqing Jiang, Chen Qian, Shuo Yang, Cheng Li, Honggang Zhang, Xiaogang Wang, Xiaoou Tang

    Residual Attention Network for Image Classification
    https://arxiv.org/abs/1704.06904
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

#"""The Attention Module is built by pre-activation Residual Unit [11] with the
#number of channels in each stage is the same as ResNet [10]."""

class PreActResidualUnit(nn.Module):
    """PreAct Residual Unit
    Args:
        in_channels: residual unit input channel number
        out_channels: residual unit output channel number
        stride: stride of the residual unit; when stride = 2, the feature map is downsampled
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        bottleneck_channels = int(out_channels / 4)
        self.residual_function = nn.Sequential(
            #1x1 conv
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, bottleneck_channels, 1, stride),

            #3x3 conv
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1),

            #1x1 conv
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck_channels, out_channels, 1)
        )

        self.shortcut = nn.Sequential()
        if stride != 2 or (in_channels != out_channels):
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=stride)

    def forward(self, x):

        res = self.residual_function(x)
        shortcut = self.shortcut(x)

        return res + shortcut

class AttentionModule1(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()
        #"""The hyperparameter p denotes the number of preprocessing Residual
        #Units before splitting into trunk branch and mask branch. t denotes
        #the number of Residual Units in trunk branch. r denotes the number of
        #Residual Units between adjacent pooling layer in the mask branch."""
        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown3 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown4 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup3 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup4 = self._make_residual(in_channels, out_channels, r)

        self.shortcut_short = PreActResidualUnit(in_channels, out_channels, 1)
        self.shortcut_long = PreActResidualUnit(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        ###We make the size of the smallest output map in each mask branch 7*7 to be consistent
        #with the smallest trunk output map size.
        ###Thus 3,2,1 max-pooling layers are used in mask branch with input size 56 * 56, 28 * 28, 14 * 14 respectively.
        x = self.pre(x)
        input_size = (x.size(2), x.size(3))

        x_t = self.trunk(x)

        #first downsample out 28
        x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        #28 shortcut
        shape1 = (x_s.size(2), x_s.size(3))
        shortcut_long = self.shortcut_long(x_s)

        #second downsample out 14
        x_s = F.max_pool2d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown2(x_s)

        #14 shortcut
        shape2 = (x_s.size(2), x_s.size(3))
        shortcut_short = self.shortcut_short(x_s)

        #third downsample out 7
        x_s = F.max_pool2d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown3(x_s)

        #mid
        x_s = self.soft_resdown4(x_s)
        x_s = self.soft_resup1(x_s)

        #first upsample out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=shape2)
        x_s += shortcut_short

        #second upsample out 28
        x_s = self.soft_resup3(x_s)
        x_s = F.interpolate(x_s, size=shape1)
        x_s += shortcut_long

        #third upsample out 56
        x_s = self.soft_resup4(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):

        layers = []
        for _ in range(p):
            layers.append(PreActResidualUnit(in_channels, out_channels, 1))

        return nn.Sequential(*layers)

class AttentionModule2(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()
        #"""The hyperparameter p denotes the number of preprocessing Residual
        #Units before splitting into trunk branch and mask branch. t denotes
        #the number of Residual Units in trunk branch. r denotes the number of
        #Residual Units between adjacent pooling layer in the mask branch."""
        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown3 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup3 = self._make_residual(in_channels, out_channels, r)

        self.shortcut = PreActResidualUnit(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        x = self.pre(x)
        input_size = (x.size(2), x.size(3))

        x_t = self.trunk(x)

        #first downsample out 14
        x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        #14 shortcut
        shape1 = (x_s.size(2), x_s.size(3))
        shortcut = self.shortcut(x_s)

        #second downsample out 7
        x_s = F.max_pool2d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown2(x_s)

        #mid
        x_s = self.soft_resdown3(x_s)
        x_s = self.soft_resup1(x_s)

        #first upsample out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=shape1)
        x_s += shortcut

        #second upsample out 28
        x_s = self.soft_resup3(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):

        layers = []
        for _ in range(p):
            layers.append(PreActResidualUnit(in_channels, out_channels, 1))

        return nn.Sequential(*layers)

class AttentionModule3(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()

        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)

        self.shortcut = PreActResidualUnit(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        x = self.pre(x)
        input_size = (x.size(2), x.size(3))

        x_t = self.trunk(x)

        #first downsample out 7
        x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        #mid
        x_s = self.soft_resdown2(x_s)
        x_s = self.soft_resup1(x_s)

        #first upsample out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):

        layers = []
        for _ in range(p):
            layers.append(PreActResidualUnit(in_channels, out_channels, 1))

        return nn.Sequential(*layers)

class Attention(nn.Module):
    """residual attention netowrk
    Args:
        block_num: attention module number for each stage
    """

    def __init__(self, block_num, class_num=100):

        super().__init__()
        self.pre_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.stage1 = self._make_stage(64, 256, block_num[0], AttentionModule1)
        self.stage2 = self._make_stage(256, 512, block_num[1], AttentionModule2)
        self.stage3 = self._make_stage(512, 1024, block_num[2], AttentionModule3)
        self.stage4 = nn.Sequential(
            PreActResidualUnit(1024, 2048, 2),
            PreActResidualUnit(2048, 2048, 1),
            PreActResidualUnit(2048, 2048, 1)
        )
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(2048, class_num)

    def forward(self, x):
        x = self.pre_conv(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.avg(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)

        return x

    def _make_stage(self, in_channels, out_channels, num, block):

        layers = []
        layers.append(PreActResidualUnit(in_channels, out_channels, 2))

        for _ in range(num):
            layers.append(block(out_channels, out_channels))

        return nn.Sequential(*layers)

def attention56():
    return Attention([1, 1, 1])

def attention92():
    return Attention([1, 2, 3])
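
Before digging into the structure, a quick sanity check that the model builds and runs. This is just a sketch appended to the listing above; the 32x32 input assumes CIFAR-100-sized images, which this implementation targets:

net = attention92()
out = net(torch.randn(2, 3, 32, 32))   # batch of 2 CIFAR-100-sized images
print(out.shape)                       # expected: torch.Size([2, 100])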


The code is long-winded and I had no patience to read it line by line. The figure in the paper already looked too complicated, which is why I skipped it in the first place, and the code turned out to be even more involved. So what to do? Just build the network and print it.
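A minimal way to do that (a sketch, assuming the listing above is saved as attention.py on the import path):

from attention import attention56   # hypothetical module name for the listing above

net = attention56()
print(net)

The printed network: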

Attention(
  (pre_conv): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (stage1): Sequential(
    (0): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2))
        (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(64, 256, kernel_size=(1, 1), stride=(2, 2))
    )
    (1): AttentionModule1(
      (pre): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (trunk): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown3): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown4): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup3): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup4): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (shortcut_short): PreActResidualUnit(
        (residual_function): Sequential(
          (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut_long): PreActResidualUnit(
        (residual_function): Sequential(
          (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (sigmoid): Sequential(
        (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (last): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (stage2): Sequential(
    (0): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(2, 2))
        (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
    )
    (1): AttentionModule2(
      (pre): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (trunk): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown3): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup3): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (shortcut): PreActResidualUnit(
        (residual_function): Sequential(
          (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
          (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
        )
        (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (sigmoid): Sequential(
        (0): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (last): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (stage3): Sequential(
    (0): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(2, 2))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2))
    )
    (1): AttentionModule3(
      (pre): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (trunk): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (shortcut): PreActResidualUnit(
        (residual_function): Sequential(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
          (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
        (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
      )
      (sigmoid): Sequential(
        (0): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
      )
      (last): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (stage4): Sequential(
    (0): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(2, 2))
        (1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2))
    )
    (1): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(2048, 2048, kernel_size=(1, 1), stride=(1, 1))
    )
    (2): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(2048, 2048, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (avg): AdaptiveAvgPool2d(output_size=1)
  (linear): Linear(in_features=2048, out_features=100, bias=True)
)

Searching the printout turns up Conv2d 167 times and shortcut 44 times, which I had no idea how to interpret.
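Those counts can be reproduced programmatically instead of by text search; a small sketch, again assuming the listing is importable as attention.py:

import torch.nn as nn
from attention import attention56   # hypothetical module name

net = attention56()
num_conv = sum(1 for m in net.modules() if isinstance(m, nn.Conv2d))
num_shortcut = sum(1 for name, _ in net.named_modules()
                   if 'shortcut' in name.split('.')[-1])
print(num_conv, num_shortcut)   # expected: 167 44, matching the text search above

But raw counts like these still do not explain the 56 and the 92.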
So look at it from the other direction: the two configurations are already given.
92 - 56 = 36
attention92 has three more attention modules than attention56 ([1, 2, 3] versus [1, 1, 1]), so those three extra modules together account for 36 layers.

Next, look at how the network is stacked:

        self.pre_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.stage1 = self._make_stage(64, 256, block_num[0], AttentionModule1)
        self.stage2 = self._make_stage(256, 512, block_num[1], AttentionModule2)
        self.stage3 = self._make_stage(512, 1024, block_num[2], AttentionModule3)
        self.stage4 = nn.Sequential(
            PreActResidualUnit(1024, 2048, 2),
            PreActResidualUnit(2048, 2048, 1),
            PreActResidualUnit(2048, 2048, 1)
        )
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(2048, class_num)

One pre_conv, four stages, and a final linear layer.
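
This top-level layout can be listed directly (sketch, same attention.py assumption as above):

from attention import attention56   # hypothetical module name

for name, child in attention56().named_children():
    print(name, type(child).__name__)
# pre_conv Sequential
# stage1 Sequential
# stage2 Sequential
# stage3 Sequential
# stage4 Sequential
# avg AdaptiveAvgPool2d
# linear Linear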

Then look at how a stage is built:

        layers = []
        layers.append(PreActResidualUnit(in_channels, out_channels, 2))

        for _ in range(num):
            layers.append(block(out_channels, out_channels))

So each stage is one downsampling PreActResidualUnit followed by the attention blocks.
From the 36 / 3 deduction above, each attention block counts as 12 layers, and a PreActResidualUnit contributes 3 conv layers. The 12 also matches the structure: only the trunk path of an attention module is counted towards the depth (the mask branch and the 1x1 shortcut convs are not), and that path is pre (1) + trunk (2) + last (1) = 4 residual units, i.e. 4 x 3 = 12 conv layers (a quick structural check of this follows the tally below).
Now the numbers can simply be added up.
1 initial conv (pre_conv) ------------------------------ 1 layer
3 stages, (12 + 3) x 3 --------------------------------- 45 layers
stage4, 3 PreActResidualUnits -------------------------- 9 layers
final fully connected layer ---------------------------- 1 layer
1 + 45 + 9 + 1 = 56, so attention56 is accounted for.
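
A quick check of the "12 layers per attention module" figure, counting only the trunk-path residual units of one AttentionModule1 (a sketch using the classes from the listing above, same attention.py assumption):

from attention import AttentionModule1   # hypothetical module name

m = AttentionModule1(256, 256)                          # p=1, t=2, r=1 as in the paper
trunk_units = len(m.pre) + len(m.trunk) + len(m.last)   # 1 + 2 + 1 = 4 residual units on the trunk path
print(trunk_units * 3)                                  # 3 convs per unit -> 12 counted layers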

The other configuration adds up the same way:
1 initial conv (pre_conv) ------------------------------ 1 layer
3 stages, 15 + (12 x 2 + 3) + (12 x 3 + 3) ------------- 81 layers
stage4, 3 PreActResidualUnits -------------------------- 9 layers
final fully connected layer ---------------------------- 1 layer
1 + 81 + 9 + 1 = 92, exactly the 92 layers of attention92.
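
The same bookkeeping as a small helper function (my counting convention, not code from the repo):

def nominal_depth(block_num, layers_per_module=12):
    # 1 initial conv (pre_conv)
    # per stage: one downsampling PreActResidualUnit (3 convs) + 12 counted layers per attention module
    # stage4: 3 PreActResidualUnits (9 convs), then the final fully connected layer
    return 1 + sum(3 + layers_per_module * n for n in block_num) + 9 + 1

print(nominal_depth([1, 1, 1]))   # 56 -> attention56
print(nominal_depth([1, 2, 3]))   # 92 -> attention92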

If this only works out by coincidence, corrections are very welcome.
