ReXNet: Eliminating the Representational Bottleneck, and a Chat About Network Design

ReXNet

A CVPR 2021 paper. The code is short and I did a quick re-implementation; the real value is the paper's thinking about network design, so let's walk through the ideas behind the model.

  • Based on the original PyTorch code, with as few changes as possible

  • This post is mostly about absorbing some useful ideas on network design

  • The last part of the paper explains why the accuracy is so good: it turns out every training trick in the book is used. Hardly a fair fight.

Starting from MobileNet V2

First, look at this figure (from the MobileNet V2 paper). In one sentence: when the input dimension is low, passing it through a nonlinearity such as ReLU loses a lot of information; when the input dimension is high enough, much less information is lost after the nonlinearity.
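
As a quick illustration of this idea (my own toy experiment, not the exact setup from the MobileNet V2 paper), the sketch below embeds 2-D points into n dimensions with a random matrix, applies ReLU, projects back with the pseudo-inverse, and measures the reconstruction error. The loss of information should shrink as n grows.

import numpy as np

np.random.seed(0)
x = np.random.randn(2, 1000)              # a cloud of 2-D points
for n in [2, 5, 15, 30]:
    t = np.random.randn(n, 2)             # random embedding into n dimensions
    y = np.maximum(t @ x, 0)              # ReLU in the embedded space
    x_rec = np.linalg.pinv(t) @ y         # project back down to 2-D
    err = np.linalg.norm(x - x_rec) / np.linalg.norm(x)
    print(f"n={n:2d}  relative reconstruction error = {err:.3f}")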

Based on this observation, MobileNet V2 proposes the linear bottleneck and inverted residuals.

Linear bottleneck

The input is expanded with a 1x1 conv, a depthwise conv is applied, and a final 1x1 conv projects the channels back down.

Of these three operations, the first two are followed by ReLU6; the last one is not, because its output dimension is low.

Inverted residuals

A regular residual block is thin in the middle and wide at both ends; an inverted residual is wide in the middle and thin at both ends.

Putting this together, the basic block is shown below, split into a stride=1 block and a stride=2 downsampling block. Note that the final 1x1 conv is not followed by ReLU6. A minimal sketch of such a block follows the diagram.

Model architecture diagram
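
For reference before the ReXNet code later on, here is a minimal sketch of a MobileNet V2-style inverted residual block in Paddle (my own simplification, not the official implementation): a 1x1 expand conv with ReLU6, a 3x3 depthwise conv with ReLU6, and a linear 1x1 projection, with a shortcut only when stride=1 and the channel counts match.

import paddle.nn as nn

class InvertedResidual(nn.Layer):
    def __init__(self, in_ch, out_ch, stride=1, expand=6):
        super().__init__()
        hidden = in_ch * expand
        self.use_shortcut = stride == 1 and in_ch == out_ch
        self.block = nn.Sequential(
            nn.Conv2D(in_ch, hidden, 1, bias_attr=False),        # 1x1 expand
            nn.BatchNorm2D(hidden),
            nn.ReLU6(),
            nn.Conv2D(hidden, hidden, 3, stride=stride, padding=1,
                      groups=hidden, bias_attr=False),            # 3x3 depthwise
            nn.BatchNorm2D(hidden),
            nn.ReLU6(),
            nn.Conv2D(hidden, out_ch, 1, bias_attr=False),        # linear 1x1 projection, no ReLU6
            nn.BatchNorm2D(out_ch),
        )

    def forward(self, x):
        out = self.block(x)
        return x + out if self.use_shortcut else out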

ReXNet builds on this design and improves on its weak points.

Some interesting ideas in ReXNet

The main idea of ReXNet is to remove representational bottlenecks. The summary below is also the set of design principles the authors propose:

  • Expand the input channel size

    Starting from the bottleneck in the Softmax layer, the authors generalize to layer bottlenecks: when the input dimension is smaller than the output dimension, the low-rank input cannot represent the higher-rank output space. For example, 2-dimensional features are a poor fit for a 10-way classification output.

    The authors argue that the input and output dimensions should satisfy an inequality (given in the paper) so that the layer bottleneck has less effect.

  • A suitable activation function: Swish-1

    Recommended reading: https://arxiv.org/abs/1710.05941

    Experiments show that the Swish function keeps a higher rank: it raises the rank of the features so the input rank gets closer to the output rank, which reduces the layer bottleneck (see the sketch after this list).

  • Multiple expand layers, with channels increasing step by step

    An expand layer is a layer whose output channels exceed its input channels. Using several expand layers keeps the rank gap between input and output from growing too large, and increasing the channels gradually does a better job of reducing the representational bottleneck.
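
As a rough, self-contained check of the rank argument (my own toy experiment, not the paper's exact measurement), the sketch below pushes random features through a random expand layer, applies different activations, and counts the singular values that stay above a threshold as an effective-rank proxy. The nonlinearities lift the rank above the input dimension, and the paper reports Swish-1 lifting it the most.

import numpy as np

def swish(x):
    return x / (1.0 + np.exp(-x))

def effective_rank(a, thresh=1e-2):
    # count singular values above thresh * the largest singular value
    s = np.linalg.svd(a, compute_uv=False)
    return int((s > thresh * s[0]).sum())

np.random.seed(0)
d_in, d_out, n = 32, 128, 2048
x = np.random.randn(d_in, n)
w = np.random.randn(d_out, d_in) / np.sqrt(d_in)   # a random "expand layer"
z = w @ x                                          # pre-activation, rank <= d_in

for name, f in [("linear", lambda v: v),
                ("relu",   lambda v: np.maximum(v, 0)),
                ("relu6",  lambda v: np.clip(v, 0, 6)),
                ("swish",  swish)]:
    print(f"{name:7s} effective rank = {effective_rank(f(z))} / {d_out}")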

Compare this with the MobileNet V2 architecture above. What differences do you notice?

  • Multiple expand layers
  • Input and output channels increase layer by layer, with only a small gap between them
  • The Swish-1 activation is introduced after the first 1x1 conv in each block

Full code

1. Model definition

import paddle
import paddle.nn as nn
from math import ceil

print(paddle.__version__)  # 2.0.1
def ConvBNAct(out, in_channels, channels, kernel=1, stride=1, pad=0,
              num_group=1, active=True, relu6=False):
    out.append(nn.Conv2D(in_channels, channels, kernel,
                         stride, pad, groups=num_group, bias_attr=False))
    out.append(nn.BatchNorm2D(channels))
    if active:
        out.append(nn.ReLU6() if relu6 else nn.ReLU())


def ConvBNSwish(out, in_channels, channels, kernel=1, stride=1, pad=0, num_group=1):
    out.append(nn.Conv2D(in_channels, channels, kernel,
                         stride, pad, groups=num_group, bias_attr=False))
    out.append(nn.BatchNorm2D(channels))
    out.append(nn.Swish())


class SE(nn.Layer):
    def __init__(self, in_channels, channels, se_ratio=12):
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        self.fc = nn.Sequential(
            nn.Conv2D(in_channels, channels // se_ratio, kernel_size=1, padding=0),
            nn.BatchNorm2D(channels // se_ratio),
            nn.ReLU(),
            nn.Conv2D(channels // se_ratio, channels, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.fc(y)
        return x * y


class LinearBottleneck(nn.Layer):
    def __init__(self, in_channels, channels, t, stride, use_se=True, se_ratio=12,
                 **kwargs):
        super(LinearBottleneck, self).__init__(**kwargs)
        self.use_shortcut = stride == 1 and in_channels <= channels
        self.in_channels = in_channels
        self.out_channels = channels

        out = []
        if t != 1:
            # 1x1 expand conv followed by Swish (ReXNet swaps Swish in here)
            dw_channels = in_channels * t
            ConvBNSwish(out, in_channels=in_channels, channels=dw_channels)
        else:
            dw_channels = in_channels

        # 3x3 depthwise conv; SE and ReLU6 are applied after it
        ConvBNAct(out, in_channels=dw_channels, channels=dw_channels, kernel=3, stride=stride, pad=1,
                  num_group=dw_channels, active=False)

        if use_se:
            out.append(SE(dw_channels, dw_channels, se_ratio))

        out.append(nn.ReLU6())
        # linear 1x1 projection (active=False, so no activation after BN)
        ConvBNAct(out, in_channels=dw_channels, channels=channels, active=False, relu6=True)
        self.out = nn.Sequential(*out)

    def forward(self, x):
        out = self.out(x)
        if self.use_shortcut:
            # ReXNet-style shortcut: out_channels >= in_channels, so only the
            # first in_channels channels receive the residual connection
            out[:, 0:self.in_channels] += x

        return out


class ReXNetV1(nn.Layer):
    def __init__(self, input_ch=16, final_ch=180, width_mult=1.0, depth_mult=1.0, classes=1000,
                 use_se=True,
                 se_ratio=12,
                 dropout_ratio=0.2,
                 bn_momentum=0.9):
        super(ReXNetV1, self).__init__()

        layers = [1, 2, 2, 3, 3, 5]
        strides = [1, 2, 2, 2, 1, 2]
        use_ses = [False, False, True, True, True, True]

        layers = [ceil(element * depth_mult) for element in layers]
        strides = sum([[element] + [1] * (layers[idx] - 1)
                       for idx, element in enumerate(strides)], [])
        if use_se:
            use_ses = sum([[element] * layers[idx] for idx, element in enumerate(use_ses)], [])
        else:
            use_ses = [False] * sum(layers[:])
        ts = [1] * layers[0] + [6] * sum(layers[1:])

        self.depth = sum(layers[:]) * 3
        stem_channel = 32 / width_mult if width_mult < 1.0 else 32
        inplanes = input_ch / width_mult if width_mult < 1.0 else input_ch

        features = []
        in_channels_group = []
        channels_group = []


        # channel schedule: output channels grow roughly linearly, adding
        # final_ch / (number of blocks) at each block before width scaling
        for i in range(self.depth // 3):
            if i == 0:
                in_channels_group.append(int(round(stem_channel * width_mult)))
                channels_group.append(int(round(inplanes * width_mult)))
            else:
                in_channels_group.append(int(round(inplanes * width_mult)))
                inplanes += final_ch / (self.depth // 3 * 1.0)
                channels_group.append(int(round(inplanes * width_mult)))

        ConvBNSwish(features, 3, int(round(stem_channel * width_mult)), kernel=3, stride=2, pad=1)

        for block_idx, (in_c, c, t, s, se) in enumerate(zip(in_channels_group, channels_group, ts, strides, use_ses)):
            features.append(LinearBottleneck(in_channels=in_c,
                                             channels=c,
                                             t=t,
                                             stride=s,
                                             use_se=se, se_ratio=se_ratio))

        pen_channels = int(1280 * width_mult)
        ConvBNSwish(features, c, pen_channels)

        features.append(nn.AdaptiveAvgPool2D(1))
        self.features = nn.Sequential(*features)
        self.output = nn.Sequential(
            nn.Dropout(dropout_ratio),
            nn.Conv2D(pen_channels, classes, 1, bias_attr=True))

    def forward(self, x):
        x = self.features(x)
        x = self.output(x).squeeze()
        return x
    
rexnet=ReXNetV1(classes=10)
rexnet
ReXNetV1(
  (features): Sequential(
    (0): Conv2D(3, 32, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)
    (1): BatchNorm2D(num_features=32, momentum=0.9, epsilon=1e-05)
    (2): Swish()
    (3): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(32, 32, kernel_size=[3, 3], padding=1, groups=32, data_format=NCHW)
        (1): BatchNorm2D(num_features=32, momentum=0.9, epsilon=1e-05)
        (2): ReLU6()
        (3): Conv2D(32, 16, kernel_size=[1, 1], data_format=NCHW)
        (4): BatchNorm2D(num_features=16, momentum=0.9, epsilon=1e-05)
      )
    )
    (4): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(16, 96, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=96, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(96, 96, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=96, data_format=NCHW)
        (4): BatchNorm2D(num_features=96, momentum=0.9, epsilon=1e-05)
        (5): ReLU6()
        (6): Conv2D(96, 27, kernel_size=[1, 1], data_format=NCHW)
        (7): BatchNorm2D(num_features=27, momentum=0.9, epsilon=1e-05)
      )
    )
    (5): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(27, 162, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=162, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(162, 162, kernel_size=[3, 3], padding=1, groups=162, data_format=NCHW)
        (4): BatchNorm2D(num_features=162, momentum=0.9, epsilon=1e-05)
        (5): ReLU6()
        (6): Conv2D(162, 38, kernel_size=[1, 1], data_format=NCHW)
        (7): BatchNorm2D(num_features=38, momentum=0.9, epsilon=1e-05)
      )
    )
    (6): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(38, 228, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=228, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(228, 228, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=228, data_format=NCHW)
        (4): BatchNorm2D(num_features=228, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(228, 19, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=19, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(19, 228, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(228, 50, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=50, momentum=0.9, epsilon=1e-05)
      )
    )
    (7): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(50, 300, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=300, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(300, 300, kernel_size=[3, 3], padding=1, groups=300, data_format=NCHW)
        (4): BatchNorm2D(num_features=300, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(300, 25, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=25, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(25, 300, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(300, 61, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=61, momentum=0.9, epsilon=1e-05)
      )
    )
    (8): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(61, 366, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=366, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(366, 366, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=366, data_format=NCHW)
        (4): BatchNorm2D(num_features=366, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(366, 30, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=30, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(30, 366, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(366, 72, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=72, momentum=0.9, epsilon=1e-05)
      )
    )
    (9): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(72, 432, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=432, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(432, 432, kernel_size=[3, 3], padding=1, groups=432, data_format=NCHW)
        (4): BatchNorm2D(num_features=432, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(432, 36, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=36, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(36, 432, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(432, 84, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=84, momentum=0.9, epsilon=1e-05)
      )
    )
    (10): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(84, 504, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=504, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(504, 504, kernel_size=[3, 3], padding=1, groups=504, data_format=NCHW)
        (4): BatchNorm2D(num_features=504, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(504, 42, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=42, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(42, 504, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(504, 95, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=95, momentum=0.9, epsilon=1e-05)
      )
    )
    (11): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(95, 570, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=570, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(570, 570, kernel_size=[3, 3], padding=1, groups=570, data_format=NCHW)
        (4): BatchNorm2D(num_features=570, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(570, 47, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=47, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(47, 570, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(570, 106, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=106, momentum=0.9, epsilon=1e-05)
      )
    )
    (12): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(106, 636, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=636, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(636, 636, kernel_size=[3, 3], padding=1, groups=636, data_format=NCHW)
        (4): BatchNorm2D(num_features=636, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(636, 53, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=53, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(53, 636, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(636, 117, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=117, momentum=0.9, epsilon=1e-05)
      )
    )
    (13): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(117, 702, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=702, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(702, 702, kernel_size=[3, 3], padding=1, groups=702, data_format=NCHW)
        (4): BatchNorm2D(num_features=702, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(702, 58, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=58, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(58, 702, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(702, 128, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=128, momentum=0.9, epsilon=1e-05)
      )
    )
    (14): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(128, 768, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=768, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(768, 768, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=768, data_format=NCHW)
        (4): BatchNorm2D(num_features=768, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(768, 64, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=64, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(64, 768, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(768, 140, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=140, momentum=0.9, epsilon=1e-05)
      )
    )
    (15): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(140, 840, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=840, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(840, 840, kernel_size=[3, 3], padding=1, groups=840, data_format=NCHW)
        (4): BatchNorm2D(num_features=840, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(840, 70, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=70, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(70, 840, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(840, 151, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=151, momentum=0.9, epsilon=1e-05)
      )
    )
    (16): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(151, 906, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=906, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(906, 906, kernel_size=[3, 3], padding=1, groups=906, data_format=NCHW)
        (4): BatchNorm2D(num_features=906, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(906, 75, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=75, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(75, 906, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(906, 162, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=162, momentum=0.9, epsilon=1e-05)
      )
    )
    (17): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(162, 972, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=972, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(972, 972, kernel_size=[3, 3], padding=1, groups=972, data_format=NCHW)
        (4): BatchNorm2D(num_features=972, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(972, 81, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=81, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(81, 972, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(972, 174, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=174, momentum=0.9, epsilon=1e-05)
      )
    )
    (18): LinearBottleneck(
      (out): Sequential(
        (0): Conv2D(174, 1044, kernel_size=[1, 1], data_format=NCHW)
        (1): BatchNorm2D(num_features=1044, momentum=0.9, epsilon=1e-05)
        (2): Swish()
        (3): Conv2D(1044, 1044, kernel_size=[3, 3], padding=1, groups=1044, data_format=NCHW)
        (4): BatchNorm2D(num_features=1044, momentum=0.9, epsilon=1e-05)
        (5): SE(
          (avg_pool): AdaptiveAvgPool2D(output_size=1)
          (fc): Sequential(
            (0): Conv2D(1044, 87, kernel_size=[1, 1], data_format=NCHW)
            (1): BatchNorm2D(num_features=87, momentum=0.9, epsilon=1e-05)
            (2): ReLU()
            (3): Conv2D(87, 1044, kernel_size=[1, 1], data_format=NCHW)
            (4): Sigmoid()
          )
        )
        (6): ReLU6()
        (7): Conv2D(1044, 185, kernel_size=[1, 1], data_format=NCHW)
        (8): BatchNorm2D(num_features=185, momentum=0.9, epsilon=1e-05)
      )
    )
    (19): Conv2D(185, 1280, kernel_size=[1, 1], data_format=NCHW)
    (20): BatchNorm2D(num_features=1280, momentum=0.9, epsilon=1e-05)
    (21): Swish()
    (22): AdaptiveAvgPool2D(output_size=1)
  )
  (output): Sequential(
    (0): Dropout(p=0.2, axis=None, mode=upscale_in_train)
    (1): Conv2D(1280, 10, kernel_size=[1, 1], data_format=NCHW)
  )
)
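
A quick sanity check (my own addition) before wiring up training: push a random batch through the network and confirm there is one logit per class.

x = paddle.randn([2, 3, 224, 224])
logits = rexnet(x)
print(logits.shape)   # expect [2, 10]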

2. Data preparation

The Cifar10 dataset is used, without heavy data augmentation.

import paddle.vision.transforms as T
from paddle.vision.datasets import Cifar10

# data preparation
transform = T.Compose([
    T.Resize(size=(224,224)),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],data_format='HWC'),
    T.ToTensor()
])

train_dataset = Cifar10(mode='train', transform=transform)
val_dataset = Cifar10(mode='test',  transform=transform)

model=paddle.Model(rexnet)
model.summary((1,3,224,224))
--------------------------------------------------------------------------------
    Layer (type)         Input Shape          Output Shape         Param #    
================================================================================
      Conv2D-1        [[1, 3, 224, 224]]   [1, 32, 112, 112]         864      
   BatchNorm2D-1     [[1, 32, 112, 112]]   [1, 32, 112, 112]         128      
      Swish-1        [[1, 32, 112, 112]]   [1, 32, 112, 112]          0       
      Conv2D-2       [[1, 32, 112, 112]]   [1, 32, 112, 112]         288      
   BatchNorm2D-2     [[1, 32, 112, 112]]   [1, 32, 112, 112]         128      
      ReLU6-1        [[1, 32, 112, 112]]   [1, 32, 112, 112]          0       
      Conv2D-3       [[1, 32, 112, 112]]   [1, 16, 112, 112]         512      
   BatchNorm2D-3     [[1, 16, 112, 112]]   [1, 16, 112, 112]         64       
 LinearBottleneck-1  [[1, 32, 112, 112]]   [1, 16, 112, 112]          0       
      Conv2D-4       [[1, 16, 112, 112]]   [1, 96, 112, 112]        1,536     
   BatchNorm2D-4     [[1, 96, 112, 112]]   [1, 96, 112, 112]         384      
      Swish-2        [[1, 96, 112, 112]]   [1, 96, 112, 112]          0       
      Conv2D-5       [[1, 96, 112, 112]]    [1, 96, 56, 56]          864      
   BatchNorm2D-5      [[1, 96, 56, 56]]     [1, 96, 56, 56]          384      
      ReLU6-2         [[1, 96, 56, 56]]     [1, 96, 56, 56]           0       
      Conv2D-6        [[1, 96, 56, 56]]     [1, 27, 56, 56]         2,592     
   BatchNorm2D-6      [[1, 27, 56, 56]]     [1, 27, 56, 56]          108      
 LinearBottleneck-2  [[1, 16, 112, 112]]    [1, 27, 56, 56]           0       
      Conv2D-7        [[1, 27, 56, 56]]     [1, 162, 56, 56]        4,374     
   BatchNorm2D-7      [[1, 162, 56, 56]]    [1, 162, 56, 56]         648      
      Swish-3         [[1, 162, 56, 56]]    [1, 162, 56, 56]          0       
      Conv2D-8        [[1, 162, 56, 56]]    [1, 162, 56, 56]        1,458     
   BatchNorm2D-8      [[1, 162, 56, 56]]    [1, 162, 56, 56]         648      
      ReLU6-3         [[1, 162, 56, 56]]    [1, 162, 56, 56]          0       
      Conv2D-9        [[1, 162, 56, 56]]    [1, 38, 56, 56]         6,156     
   BatchNorm2D-9      [[1, 38, 56, 56]]     [1, 38, 56, 56]          152      
 LinearBottleneck-3   [[1, 27, 56, 56]]     [1, 38, 56, 56]           0       
     Conv2D-10        [[1, 38, 56, 56]]     [1, 228, 56, 56]        8,664     
   BatchNorm2D-10     [[1, 228, 56, 56]]    [1, 228, 56, 56]         912      
      Swish-4         [[1, 228, 56, 56]]    [1, 228, 56, 56]          0       
     Conv2D-11        [[1, 228, 56, 56]]    [1, 228, 28, 28]        2,052     
   BatchNorm2D-11     [[1, 228, 28, 28]]    [1, 228, 28, 28]         912      
AdaptiveAvgPool2D-1   [[1, 228, 28, 28]]     [1, 228, 1, 1]           0       
     Conv2D-12         [[1, 228, 1, 1]]      [1, 19, 1, 1]          4,351     
   BatchNorm2D-12      [[1, 19, 1, 1]]       [1, 19, 1, 1]           76       
       ReLU-1          [[1, 19, 1, 1]]       [1, 19, 1, 1]            0       
     Conv2D-13         [[1, 19, 1, 1]]       [1, 228, 1, 1]         4,560     
     Sigmoid-1         [[1, 228, 1, 1]]      [1, 228, 1, 1]           0       
        SE-1          [[1, 228, 28, 28]]    [1, 228, 28, 28]          0       
      ReLU6-4         [[1, 228, 28, 28]]    [1, 228, 28, 28]          0       
     Conv2D-14        [[1, 228, 28, 28]]    [1, 50, 28, 28]        11,400     
   BatchNorm2D-13     [[1, 50, 28, 28]]     [1, 50, 28, 28]          200      
 LinearBottleneck-4   [[1, 38, 56, 56]]     [1, 50, 28, 28]           0       
     Conv2D-15        [[1, 50, 28, 28]]     [1, 300, 28, 28]       15,000     
   BatchNorm2D-14     [[1, 300, 28, 28]]    [1, 300, 28, 28]        1,200     
      Swish-5         [[1, 300, 28, 28]]    [1, 300, 28, 28]          0       
     Conv2D-16        [[1, 300, 28, 28]]    [1, 300, 28, 28]        2,700     
   BatchNorm2D-15     [[1, 300, 28, 28]]    [1, 300, 28, 28]        1,200     
AdaptiveAvgPool2D-2   [[1, 300, 28, 28]]     [1, 300, 1, 1]           0       
     Conv2D-17         [[1, 300, 1, 1]]      [1, 25, 1, 1]          7,525     
   BatchNorm2D-16      [[1, 25, 1, 1]]       [1, 25, 1, 1]           100      
       ReLU-2          [[1, 25, 1, 1]]       [1, 25, 1, 1]            0       
     Conv2D-18         [[1, 25, 1, 1]]       [1, 300, 1, 1]         7,800     
     Sigmoid-2         [[1, 300, 1, 1]]      [1, 300, 1, 1]           0       
        SE-2          [[1, 300, 28, 28]]    [1, 300, 28, 28]          0       
      ReLU6-5         [[1, 300, 28, 28]]    [1, 300, 28, 28]          0       
     Conv2D-19        [[1, 300, 28, 28]]    [1, 61, 28, 28]        18,300     
   BatchNorm2D-17     [[1, 61, 28, 28]]     [1, 61, 28, 28]          244      
 LinearBottleneck-5   [[1, 50, 28, 28]]     [1, 61, 28, 28]           0       
     Conv2D-20        [[1, 61, 28, 28]]     [1, 366, 28, 28]       22,326     
   BatchNorm2D-18     [[1, 366, 28, 28]]    [1, 366, 28, 28]        1,464     
      Swish-6         [[1, 366, 28, 28]]    [1, 366, 28, 28]          0       
     Conv2D-21        [[1, 366, 28, 28]]    [1, 366, 14, 14]        3,294     
   BatchNorm2D-19     [[1, 366, 14, 14]]    [1, 366, 14, 14]        1,464     
AdaptiveAvgPool2D-3   [[1, 366, 14, 14]]     [1, 366, 1, 1]           0       
     Conv2D-22         [[1, 366, 1, 1]]      [1, 30, 1, 1]         11,010     
   BatchNorm2D-20      [[1, 30, 1, 1]]       [1, 30, 1, 1]           120      
       ReLU-3          [[1, 30, 1, 1]]       [1, 30, 1, 1]            0       
     Conv2D-23         [[1, 30, 1, 1]]       [1, 366, 1, 1]        11,346     
     Sigmoid-3         [[1, 366, 1, 1]]      [1, 366, 1, 1]           0       
        SE-3          [[1, 366, 14, 14]]    [1, 366, 14, 14]          0       
      ReLU6-6         [[1, 366, 14, 14]]    [1, 366, 14, 14]          0       
     Conv2D-24        [[1, 366, 14, 14]]    [1, 72, 14, 14]        26,352     
   BatchNorm2D-21     [[1, 72, 14, 14]]     [1, 72, 14, 14]          288      
 LinearBottleneck-6   [[1, 61, 28, 28]]     [1, 72, 14, 14]           0       
     Conv2D-25        [[1, 72, 14, 14]]     [1, 432, 14, 14]       31,104     
   BatchNorm2D-22     [[1, 432, 14, 14]]    [1, 432, 14, 14]        1,728     
      Swish-7         [[1, 432, 14, 14]]    [1, 432, 14, 14]          0       
     Conv2D-26        [[1, 432, 14, 14]]    [1, 432, 14, 14]        3,888     
   BatchNorm2D-23     [[1, 432, 14, 14]]    [1, 432, 14, 14]        1,728     
AdaptiveAvgPool2D-4   [[1, 432, 14, 14]]     [1, 432, 1, 1]           0       
     Conv2D-27         [[1, 432, 1, 1]]      [1, 36, 1, 1]         15,588     
   BatchNorm2D-24      [[1, 36, 1, 1]]       [1, 36, 1, 1]           144      
       ReLU-4          [[1, 36, 1, 1]]       [1, 36, 1, 1]            0       
     Conv2D-28         [[1, 36, 1, 1]]       [1, 432, 1, 1]        15,984     
     Sigmoid-4         [[1, 432, 1, 1]]      [1, 432, 1, 1]           0       
        SE-4          [[1, 432, 14, 14]]    [1, 432, 14, 14]          0       
      ReLU6-7         [[1, 432, 14, 14]]    [1, 432, 14, 14]          0       
     Conv2D-29        [[1, 432, 14, 14]]    [1, 84, 14, 14]        36,288     
   BatchNorm2D-25     [[1, 84, 14, 14]]     [1, 84, 14, 14]          336      
 LinearBottleneck-7   [[1, 72, 14, 14]]     [1, 84, 14, 14]           0       
     Conv2D-30        [[1, 84, 14, 14]]     [1, 504, 14, 14]       42,336     
   BatchNorm2D-26     [[1, 504, 14, 14]]    [1, 504, 14, 14]        2,016     
      Swish-8         [[1, 504, 14, 14]]    [1, 504, 14, 14]          0       
     Conv2D-31        [[1, 504, 14, 14]]    [1, 504, 14, 14]        4,536     
   BatchNorm2D-27     [[1, 504, 14, 14]]    [1, 504, 14, 14]        2,016     
AdaptiveAvgPool2D-5   [[1, 504, 14, 14]]     [1, 504, 1, 1]           0       
     Conv2D-32         [[1, 504, 1, 1]]      [1, 42, 1, 1]         21,210     
   BatchNorm2D-28      [[1, 42, 1, 1]]       [1, 42, 1, 1]           168      
       ReLU-5          [[1, 42, 1, 1]]       [1, 42, 1, 1]            0       
     Conv2D-33         [[1, 42, 1, 1]]       [1, 504, 1, 1]        21,672     
     Sigmoid-5         [[1, 504, 1, 1]]      [1, 504, 1, 1]           0       
        SE-5          [[1, 504, 14, 14]]    [1, 504, 14, 14]          0       
      ReLU6-8         [[1, 504, 14, 14]]    [1, 504, 14, 14]          0       
     Conv2D-34        [[1, 504, 14, 14]]    [1, 95, 14, 14]        47,880     
   BatchNorm2D-29     [[1, 95, 14, 14]]     [1, 95, 14, 14]          380      
 LinearBottleneck-8   [[1, 84, 14, 14]]     [1, 95, 14, 14]           0       
     Conv2D-35        [[1, 95, 14, 14]]     [1, 570, 14, 14]       54,150     
   BatchNorm2D-30     [[1, 570, 14, 14]]    [1, 570, 14, 14]        2,280     
      Swish-9         [[1, 570, 14, 14]]    [1, 570, 14, 14]          0       
     Conv2D-36        [[1, 570, 14, 14]]    [1, 570, 14, 14]        5,130     
   BatchNorm2D-31     [[1, 570, 14, 14]]    [1, 570, 14, 14]        2,280     
AdaptiveAvgPool2D-6   [[1, 570, 14, 14]]     [1, 570, 1, 1]           0       
     Conv2D-37         [[1, 570, 1, 1]]      [1, 47, 1, 1]         26,837     
   BatchNorm2D-32      [[1, 47, 1, 1]]       [1, 47, 1, 1]           188      
       ReLU-6          [[1, 47, 1, 1]]       [1, 47, 1, 1]            0       
     Conv2D-38         [[1, 47, 1, 1]]       [1, 570, 1, 1]        27,360     
     Sigmoid-6         [[1, 570, 1, 1]]      [1, 570, 1, 1]           0       
        SE-6          [[1, 570, 14, 14]]    [1, 570, 14, 14]          0       
      ReLU6-9         [[1, 570, 14, 14]]    [1, 570, 14, 14]          0       
     Conv2D-39        [[1, 570, 14, 14]]    [1, 106, 14, 14]       60,420     
   BatchNorm2D-33     [[1, 106, 14, 14]]    [1, 106, 14, 14]         424      
 LinearBottleneck-9   [[1, 95, 14, 14]]     [1, 106, 14, 14]          0       
     Conv2D-40        [[1, 106, 14, 14]]    [1, 636, 14, 14]       67,416     
   BatchNorm2D-34     [[1, 636, 14, 14]]    [1, 636, 14, 14]        2,544     
      Swish-10        [[1, 636, 14, 14]]    [1, 636, 14, 14]          0       
     Conv2D-41        [[1, 636, 14, 14]]    [1, 636, 14, 14]        5,724     
   BatchNorm2D-35     [[1, 636, 14, 14]]    [1, 636, 14, 14]        2,544     
AdaptiveAvgPool2D-7   [[1, 636, 14, 14]]     [1, 636, 1, 1]           0       
     Conv2D-42         [[1, 636, 1, 1]]      [1, 53, 1, 1]         33,761     
   BatchNorm2D-36      [[1, 53, 1, 1]]       [1, 53, 1, 1]           212      
       ReLU-7          [[1, 53, 1, 1]]       [1, 53, 1, 1]            0       
     Conv2D-43         [[1, 53, 1, 1]]       [1, 636, 1, 1]        34,344     
     Sigmoid-7         [[1, 636, 1, 1]]      [1, 636, 1, 1]           0       
        SE-7          [[1, 636, 14, 14]]    [1, 636, 14, 14]          0       
      ReLU6-10        [[1, 636, 14, 14]]    [1, 636, 14, 14]          0       
     Conv2D-44        [[1, 636, 14, 14]]    [1, 117, 14, 14]       74,412     
   BatchNorm2D-37     [[1, 117, 14, 14]]    [1, 117, 14, 14]         468      
LinearBottleneck-10   [[1, 106, 14, 14]]    [1, 117, 14, 14]          0       
     Conv2D-45        [[1, 117, 14, 14]]    [1, 702, 14, 14]       82,134     
   BatchNorm2D-38     [[1, 702, 14, 14]]    [1, 702, 14, 14]        2,808     
      Swish-11        [[1, 702, 14, 14]]    [1, 702, 14, 14]          0       
     Conv2D-46        [[1, 702, 14, 14]]    [1, 702, 14, 14]        6,318     
   BatchNorm2D-39     [[1, 702, 14, 14]]    [1, 702, 14, 14]        2,808     
AdaptiveAvgPool2D-8   [[1, 702, 14, 14]]     [1, 702, 1, 1]           0       
     Conv2D-47         [[1, 702, 1, 1]]      [1, 58, 1, 1]         40,774     
   BatchNorm2D-40      [[1, 58, 1, 1]]       [1, 58, 1, 1]           232      
       ReLU-8          [[1, 58, 1, 1]]       [1, 58, 1, 1]            0       
     Conv2D-48         [[1, 58, 1, 1]]       [1, 702, 1, 1]        41,418     
     Sigmoid-8         [[1, 702, 1, 1]]      [1, 702, 1, 1]           0       
        SE-8          [[1, 702, 14, 14]]    [1, 702, 14, 14]          0       
      ReLU6-11        [[1, 702, 14, 14]]    [1, 702, 14, 14]          0       
     Conv2D-49        [[1, 702, 14, 14]]    [1, 128, 14, 14]       89,856     
   BatchNorm2D-41     [[1, 128, 14, 14]]    [1, 128, 14, 14]         512      
LinearBottleneck-11   [[1, 117, 14, 14]]    [1, 128, 14, 14]          0       
     Conv2D-50        [[1, 128, 14, 14]]    [1, 768, 14, 14]       98,304     
   BatchNorm2D-42     [[1, 768, 14, 14]]    [1, 768, 14, 14]        3,072     
      Swish-12        [[1, 768, 14, 14]]    [1, 768, 14, 14]          0       
     Conv2D-51        [[1, 768, 14, 14]]     [1, 768, 7, 7]         6,912     
   BatchNorm2D-43      [[1, 768, 7, 7]]      [1, 768, 7, 7]         3,072     
AdaptiveAvgPool2D-9    [[1, 768, 7, 7]]      [1, 768, 1, 1]           0       
     Conv2D-52         [[1, 768, 1, 1]]      [1, 64, 1, 1]         49,216     
   BatchNorm2D-44      [[1, 64, 1, 1]]       [1, 64, 1, 1]           256      
       ReLU-9          [[1, 64, 1, 1]]       [1, 64, 1, 1]            0       
     Conv2D-53         [[1, 64, 1, 1]]       [1, 768, 1, 1]        49,920     
     Sigmoid-9         [[1, 768, 1, 1]]      [1, 768, 1, 1]           0       
        SE-9           [[1, 768, 7, 7]]      [1, 768, 7, 7]           0       
      ReLU6-12         [[1, 768, 7, 7]]      [1, 768, 7, 7]           0       
     Conv2D-54         [[1, 768, 7, 7]]      [1, 140, 7, 7]        107,520    
   BatchNorm2D-45      [[1, 140, 7, 7]]      [1, 140, 7, 7]          560      
LinearBottleneck-12   [[1, 128, 14, 14]]     [1, 140, 7, 7]           0       
     Conv2D-55         [[1, 140, 7, 7]]      [1, 840, 7, 7]        117,600    
   BatchNorm2D-46      [[1, 840, 7, 7]]      [1, 840, 7, 7]         3,360     
      Swish-13         [[1, 840, 7, 7]]      [1, 840, 7, 7]           0       
     Conv2D-56         [[1, 840, 7, 7]]      [1, 840, 7, 7]         7,560     
   BatchNorm2D-47      [[1, 840, 7, 7]]      [1, 840, 7, 7]         3,360     
AdaptiveAvgPool2D-10   [[1, 840, 7, 7]]      [1, 840, 1, 1]           0       
     Conv2D-57         [[1, 840, 1, 1]]      [1, 70, 1, 1]         58,870     
   BatchNorm2D-48      [[1, 70, 1, 1]]       [1, 70, 1, 1]           280      
      ReLU-10          [[1, 70, 1, 1]]       [1, 70, 1, 1]            0       
     Conv2D-58         [[1, 70, 1, 1]]       [1, 840, 1, 1]        59,640     
     Sigmoid-10        [[1, 840, 1, 1]]      [1, 840, 1, 1]           0       
       SE-10           [[1, 840, 7, 7]]      [1, 840, 7, 7]           0       
      ReLU6-13         [[1, 840, 7, 7]]      [1, 840, 7, 7]           0       
     Conv2D-59         [[1, 840, 7, 7]]      [1, 151, 7, 7]        126,840    
   BatchNorm2D-49      [[1, 151, 7, 7]]      [1, 151, 7, 7]          604      
LinearBottleneck-13    [[1, 140, 7, 7]]      [1, 151, 7, 7]           0       
     Conv2D-60         [[1, 151, 7, 7]]      [1, 906, 7, 7]        136,806    
   BatchNorm2D-50      [[1, 906, 7, 7]]      [1, 906, 7, 7]         3,624     
      Swish-14         [[1, 906, 7, 7]]      [1, 906, 7, 7]           0       
     Conv2D-61         [[1, 906, 7, 7]]      [1, 906, 7, 7]         8,154     
   BatchNorm2D-51      [[1, 906, 7, 7]]      [1, 906, 7, 7]         3,624     
AdaptiveAvgPool2D-11   [[1, 906, 7, 7]]      [1, 906, 1, 1]           0       
     Conv2D-62         [[1, 906, 1, 1]]      [1, 75, 1, 1]         68,025     
   BatchNorm2D-52      [[1, 75, 1, 1]]       [1, 75, 1, 1]           300      
      ReLU-11          [[1, 75, 1, 1]]       [1, 75, 1, 1]            0       
     Conv2D-63         [[1, 75, 1, 1]]       [1, 906, 1, 1]        68,856     
     Sigmoid-11        [[1, 906, 1, 1]]      [1, 906, 1, 1]           0       
       SE-11           [[1, 906, 7, 7]]      [1, 906, 7, 7]           0       
      ReLU6-14         [[1, 906, 7, 7]]      [1, 906, 7, 7]           0       
     Conv2D-64         [[1, 906, 7, 7]]      [1, 162, 7, 7]        146,772    
   BatchNorm2D-53      [[1, 162, 7, 7]]      [1, 162, 7, 7]          648      
LinearBottleneck-14    [[1, 151, 7, 7]]      [1, 162, 7, 7]           0       
     Conv2D-65         [[1, 162, 7, 7]]      [1, 972, 7, 7]        157,464    
   BatchNorm2D-54      [[1, 972, 7, 7]]      [1, 972, 7, 7]         3,888     
      Swish-15         [[1, 972, 7, 7]]      [1, 972, 7, 7]           0       
     Conv2D-66         [[1, 972, 7, 7]]      [1, 972, 7, 7]         8,748     
   BatchNorm2D-55      [[1, 972, 7, 7]]      [1, 972, 7, 7]         3,888     
AdaptiveAvgPool2D-12   [[1, 972, 7, 7]]      [1, 972, 1, 1]           0       
     Conv2D-67         [[1, 972, 1, 1]]      [1, 81, 1, 1]         78,813     
   BatchNorm2D-56      [[1, 81, 1, 1]]       [1, 81, 1, 1]           324      
      ReLU-12          [[1, 81, 1, 1]]       [1, 81, 1, 1]            0       
     Conv2D-68         [[1, 81, 1, 1]]       [1, 972, 1, 1]        79,704     
     Sigmoid-12        [[1, 972, 1, 1]]      [1, 972, 1, 1]           0       
       SE-12           [[1, 972, 7, 7]]      [1, 972, 7, 7]           0       
      ReLU6-15         [[1, 972, 7, 7]]      [1, 972, 7, 7]           0       
     Conv2D-69         [[1, 972, 7, 7]]      [1, 174, 7, 7]        169,128    
   BatchNorm2D-57      [[1, 174, 7, 7]]      [1, 174, 7, 7]          696      
LinearBottleneck-15    [[1, 162, 7, 7]]      [1, 174, 7, 7]           0       
     Conv2D-70         [[1, 174, 7, 7]]     [1, 1044, 7, 7]        181,656    
   BatchNorm2D-58     [[1, 1044, 7, 7]]     [1, 1044, 7, 7]         4,176     
      Swish-16        [[1, 1044, 7, 7]]     [1, 1044, 7, 7]           0       
     Conv2D-71        [[1, 1044, 7, 7]]     [1, 1044, 7, 7]         9,396     
   BatchNorm2D-59     [[1, 1044, 7, 7]]     [1, 1044, 7, 7]         4,176     
AdaptiveAvgPool2D-13  [[1, 1044, 7, 7]]     [1, 1044, 1, 1]           0       
     Conv2D-72        [[1, 1044, 1, 1]]      [1, 87, 1, 1]         90,915     
   BatchNorm2D-60      [[1, 87, 1, 1]]       [1, 87, 1, 1]           348      
      ReLU-13          [[1, 87, 1, 1]]       [1, 87, 1, 1]            0       
     Conv2D-73         [[1, 87, 1, 1]]      [1, 1044, 1, 1]        91,872     
     Sigmoid-13       [[1, 1044, 1, 1]]     [1, 1044, 1, 1]           0       
       SE-13          [[1, 1044, 7, 7]]     [1, 1044, 7, 7]           0       
      ReLU6-16        [[1, 1044, 7, 7]]     [1, 1044, 7, 7]           0       
     Conv2D-74        [[1, 1044, 7, 7]]      [1, 185, 7, 7]        193,140    
   BatchNorm2D-61      [[1, 185, 7, 7]]      [1, 185, 7, 7]          740      
LinearBottleneck-16    [[1, 174, 7, 7]]      [1, 185, 7, 7]           0       
     Conv2D-75         [[1, 185, 7, 7]]     [1, 1280, 7, 7]        236,800    
   BatchNorm2D-62     [[1, 1280, 7, 7]]     [1, 1280, 7, 7]         5,120     
      Swish-17        [[1, 1280, 7, 7]]     [1, 1280, 7, 7]           0       
AdaptiveAvgPool2D-14  [[1, 1280, 7, 7]]     [1, 1280, 1, 1]           0       
     Dropout-1        [[1, 1280, 1, 1]]     [1, 1280, 1, 1]           0       
     Conv2D-76        [[1, 1280, 1, 1]]      [1, 10, 1, 1]         12,810     
================================================================================
Total params: 3,570,061
Trainable params: 3,487,305
Non-trainable params: 82,756
--------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 179.95
Params size (MB): 13.62
Estimated Total Size (MB): 194.15
--------------------------------------------------------------------------------

{'total_params': 3570061, 'trainable_params': 3487305}

3. Model training

model.prepare(optimizer=paddle.optimizer.Adam(learning_rate=0.001,parameters=model.parameters()),
              loss=paddle.nn.CrossEntropyLoss(),
              metrics=paddle.metric.Accuracy())


model.fit(
    train_data=train_dataset, 
    eval_data=val_dataset, 
    batch_size=128, 
    epochs=10, 
    verbose=1, 
)
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/10


/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return (isinstance(seq, collections.Sequence) and
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:648: UserWarning: When training, we now always track global mean and variance.
  "When training, we now always track global mean and variance.")


step 391/391 [==============================] - loss: 1.4491 - acc: 0.3863 - 575ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 1.7882 - acc: 0.4751 - 456ms/step         
Eval samples: 10000
Epoch 2/10
step 391/391 [==============================] - loss: 1.1541 - acc: 0.5422 - 581ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 1.2044 - acc: 0.5535 - 452ms/step         
Eval samples: 10000
Epoch 3/10
step 391/391 [==============================] - loss: 0.9998 - acc: 0.6278 - 580ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.9721 - acc: 0.6489 - 445ms/step         
Eval samples: 10000
Epoch 4/10
step 391/391 [==============================] - loss: 0.9672 - acc: 0.6923 - 580ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.9084 - acc: 0.6742 - 463ms/step         
Eval samples: 10000
Epoch 5/10
step 391/391 [==============================] - loss: 0.7523 - acc: 0.7179 - 589ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 1.0783 - acc: 0.7013 - 457ms/step         
Eval samples: 10000
Epoch 6/10
step 391/391 [==============================] - loss: 0.5859 - acc: 0.7411 - 586ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.7804 - acc: 0.7353 - 448ms/step         
Eval samples: 10000
Epoch 7/10
step 391/391 [==============================] - loss: 0.9060 - acc: 0.7618 - 591ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 1.1076 - acc: 0.7308 - 454ms/step         
Eval samples: 10000
Epoch 8/10
step 391/391 [==============================] - loss: 0.5531 - acc: 0.7673 - 592ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.9240 - acc: 0.7258 - 459ms/step         
Eval samples: 10000
Epoch 9/10
step 391/391 [==============================] - loss: 0.6456 - acc: 0.7740 - 594ms/step        
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.6419 - acc: 0.7475 - 461ms/step         
Eval samples: 10000
Epoch 10/10
step 391/391 [==============================] - loss: 0.4422 - acc: 0.8106 - 597ms/step         
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 79/79 [==============================] - loss: 0.6156 - acc: 0.7785 - 469ms/step         
Eval samples: 10000
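
After training, the same high-level API can run a standalone evaluation pass. A brief sketch, reusing val_dataset from above (the exact keys of the returned dict may differ by Paddle version):

result = model.evaluate(val_dataset, batch_size=128, verbose=1)
print(result)   # e.g. {'loss': ..., 'acc': ...}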

Summary

  • The main goal of this paper is to reduce the representational bottleneck through small modifications to MobileNet

  • The paper reports 77.9% top-1 accuracy on ImageNet

  • One flaw is that the authors rely on a lot of tricks when training ReXNet. In practice, without those tricks and without pretrained weights, Paddle's built-in MobileNet V2 fits better and converges faster

  • At inference time, ReXNet is slower than a MobileNet V2-1.2 with the same FLOPs, which comes down to the architecture. The main contribution is the set of design principles, which also encourages NAS to search for better networks
