码字不易,欢迎点赞。文章同步发布在公众号:CV前沿
在上一篇文章中,我们已经介绍了一种针对移动端和嵌入式设备的卷积神经网络—-MobileNet。今天我们要介绍的是旷视科技在2017年12月份提出的更加高效的移动端卷积神经网络—-ShuffleNet。
在ShuffleNet网络中使用了两个创新的操作:
-
pointwise group convolution(逐点组卷积)
-
channle shuffle(通道混洗)
概要来说,逐点组卷积是降低了逐点卷积(也即是1*1卷积)的计算复杂度; 同时为了消除多个组卷积堆叠产生的副作用,采用通道混洗的操作来改善跨特征通道的信息流动。使得ShuffleNet网络在保持准确率的情况下,极大的降低了计算成本。ShuffleNet网络在ImageNet竞赛和MS COCO竞赛中均表现了比其他移动端先进网络更优越的性能。
逐点组卷积
我们在文章中介绍移动端不同类型的卷积操作时,其中讲到了普通卷积、组卷积和逐点组卷积这几个概念。
这里再简单回顾一下,组卷积是在输入特征图的通道方向执行分组;逐点组卷积本只是组卷积的一种特殊形式,特殊的地方在于它的卷积核的核大小为1*1。如下图所示。
那么为什么ShuffleNet网络要使用逐点组卷积呢?
这要从ResNext讲起。
我们知道,ResNext就是对ResNet的残差单元进行了微创新,创新点就是组卷积,从而提升了网络的精度。ResNet和ResNext的残差单元如下所示。
ResNet unit 结构图
ResNext unit 结构图
但对于ResNext的一个残差单元来说,逐点卷积(1*1卷积)占整个残差单元计算量的93.4%,可见逐点卷积的计算成本是很高的。那么在小型网络中,为了符合移动端设备上有限的计算资源,代价昂贵的逐点卷积会导致网络整体变窄(每层通道数变少),进而可能带来精度的大幅度下降。
总结来说,由于逐点卷积的昂贵的计算开销,使得网络在满足设备计算资源情况下,无法满足精度需求。从而这种网络结构不适应于移动端和嵌入式设备。
为了解决这个问题,本文提出了逐点组卷积。逐点组卷积的特点是通过保证每个卷积操作仅仅是作用在对应的输入通道组,使得大大减少了计算量。
我们根据对组卷积的知识,假设组个数为 g g g, 那么逐点组卷积相对于逐点卷积,计算量下降了 g g g倍。
但是,如果多个组卷积堆叠在一起,会产生一个副作用:某个通道的输出结果,仅来自于一小部分输入通道。如下图图(a)。
这个副作用会导致在组与组之间信息流动的阻塞,以及表达能力的弱化。
那么我们如何解决这个问题呢? 这用到了本文的第二个创新—通道混洗。
通道混洗
如图所示。通过对比图(a)和图(b),我们在第一个逐点组卷积之后,对输出的结果的通道次序进行打乱,比如原始在通道维度上的索引是0,1,2,3,4,5,6,7,8;那么打乱后变为了0,3,6,1,4,7,2,5,8。
经过这样打乱之后,输出通道就不再仅仅来自于是一小部分输入通道,也会来自其他的通道。即输出通道和输入通道完全的关联了。形成的效果如图(c)所示。
通道混洗的操作实现了多组卷积层的跨组信息流动。
ShuffleNet Unit
利用逐点组卷积和通道打乱的操作,我们可以建立一个针对于小型网络特别设计的ShuffleNet unit。
最初设计是一个残差块(residual block)如图(a)所示。
然后在(a)的残差分支中,对于其中的3 * 3卷积层,我们应用一个计算成本低的3 * 3 的DW卷积。
然后我们替换了第一个逐点卷积,改为逐点组卷积,后面跟一个通道混洗操作。如图(b)所示。
图(b)中的第二个逐点组卷积的目的是为了恢复输出通道数量(升维),从而和输入通道数量一致,以便能够实现与捷径分支的输出结果进行相加操作。(相加要求两个分支的特征图在宽度、高度和深度均一致)
这就形成了一个完整的ShuffleNet unit了。
此外,我们知道卷积神经网络都需要有降采样操作,一种实现方式是采用最大池化层,另一种做法是使用stride=2的卷积来实现。
在ShuffleNet unit中,同样是采用的stride=2的卷积。如图(c)所示。具体做法是分别在捷径分支的分支设置stride=2和主分支的3*3 DW卷积中设置stride=2,从而既能够实现降采样的操作,同时又能够实现两个分支输出结果的融合。
这里还需要注意的两点是:
-
捷径分支上它采样的是3 * 3的平均池化。
-
融合没有采用相加的方法,而是通道方向的拼接。文中介绍到这样是为了更容易以很少的计算量来扩大通道维度。
与ResNet,ResNext计算量的对比
假设输入特征图形状是 c ∗ h ∗ w c*h*w c∗h∗w,以及bottleneck 模块的通道数是m,那么三个网络的残差单元的计算量分别是:
ResNet unit:
1 ∗ 1 ∗ c ∗ m ∗ h ∗ w + 3 ∗ 3 ∗ m ∗ m ∗ h ∗ w + 1 ∗ 1 ∗ m ∗ c ∗ h ∗ w = h w ( 2 c m + 9 m