简介
在深度可分离卷积中,在深度卷积(分组卷积)后使用 1×1 卷积融合不同通道信息。本文提出在分组卷积中引入通道混洗,增强分组卷积(通常每组通道数大于一)时不同通道间的信息交互。
0. Abstract
基于 shuffle 模块,ShuffleNet在保持模型精度的前提下大幅减少了计算量。相比于 MobileNet,在计算量均为 40 MFLOPs 左右时错误率降低 7.8%。
1. Introduction
作者指出 X c e p t i o n {\rm Xception} Xception 和 R e s N e t {\rm ResNet} ResNet 等深度网络难以应用在小型网络中的主要原因是其中 1×1 卷积的计算代价。基于深度可分离卷积的思路,作者将深度卷积(逐通道卷积)改为组卷积,并且引入通道混洗模块来加强通道间的信息交互。
2. Related Work
Efficient Model Designs GooLeNet 基于分支结构在增加网络深度的同时保证了较少的计算量;SqueezeNet 提出的 Fire 模块在保证模型精度的同时大幅减少计算量;ResNet 基于残差模块大幅提高了模型的精度;SENet 提出通道注意力机制,以较小代价提高模型的性能;基于强化学习和神经架构搜索,NASNet 获得了与本文模型相当的结果。
Group Convolution 分组卷积即将输入特征图分成若干组,每个组分别计算,然后合并每组的计算结果。这种技术最早在 AlexNet 中使用,由于计算资源的受限,将特征图按通道划分。
Channel Shuffle Operation 尽管 CNN 库的 cuda-convert 支持通道混洗,但以前少有将其应用于网络设计中。
Model Acceleration 对预训练模型的压缩通常采用剪枝、量化、知识蒸馏等。
3. Approach
3.1 Channel Shuffle for Group Convolution
当前主流神经网络结构基于模块化设计,Xception 和 ResNeXt 基于深度可分离卷积和分组卷积构建基本模块。上述网络的设计没有充分考虑 1×1 卷积的特点。如在 ResNeXt 中,将 3×3 卷积使用了分组卷积,这使得 1×1 卷积占据最终模型计算量的 93.4%。
解决上述问题最直接的办法是使用稀疏的通道连接,如在 1×1 卷积层也使用分组卷积。但仅简单地将多个分组卷积块堆叠在一起,这将造成某输出组的信息仅来源于对应的输入组。如下图:
如果某组的输出来自于不同组,那么输入和输出通道将会充分关联。这即是本文提出的通道混洗:
def channel_shuffle(x, groups):
# n,c,h,w
batch_size, num_channels, height, width = x.shape[0:4]
assert num_channels % groups == 0, 'num_channels should be divisible by groups'
# groups指定分组数,channels_per_group表示每组的通道数
channels_per_group = num_channels // groups
# reshape,将x的通道维度拆分
x = paddle.reshape(
x=x, shape=[batch_size, groups, channels_per_group, height, width])
# transpose
x = paddle.transpose(x=x, perm=[0, 2, 1, 3, 4])
# reshape回原x的形状
x = paddle.reshape(x=x, shape=[batch_size, num_channels, height, width])
return x
3.2 ShuffleNet Unit
基于通道混洗构建的 shuffle 模块:
对于输入大小为 c×h×w 的特征图,模块通道数为 m。ResNet 的计算量为 hw(2cm+9mm) ,ResNeXt 的计算量为 hw(2cm+9mm/g),而 ShuffleNet 的计算量为 hw(2cm/g+9m)。
3.3 Network Architecture
基于 shuffle 模块构建的 ShuffleNet:
4. Experiments
不同宽度和不同组数的 ShuffleNet 比较:
与 MobileNet 的比较,摘要部分的叙述来自于最后一栏:
与其他模型的比较:
参考
【1】Zhang X, Zhou X, Lin M, et al. Shufflenet: An extremely efficient convolutional neural network for mobile devices[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 6848-6856.
【2】https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.3/ppdet/modeling/ops.py#L1593