参考代码:dmcp
1. 概述
导读:在网络剪枝领域中已经有一些工作将结构搜索的概念引入到剪枝方法中,如AMC使用强化学习的方式使控制器输出每一层的裁剪比例。但是正如NAS的发展过程一样,这些基于强化学习的搜索方法需要大量的训练和验证过程,而最直接的便是使用类似DARTS的直接梯度优化方法。而这篇文章中将网络的channel剪枝建模为一个可微分的马尔可夫过程(Differentiable Markov Channel Pruning,DMCP),通过直接产生的梯度信息进行优化,因而更加高效,整个剪枝的过程也变为了一个马尔可夫决策过程。为了减少额外参数的引入,文章还将卷积层的channel进行分组,组内实行参数共享,从而减少参数量。使用文章的方法在MobileNet-v2上剪裁掉了30%的参数,掉点0.1%;在ResNet-50上剪裁掉了44%的参数,掉点0.4%。
网络剪枝的过程可以看作是在一个现有搜索空间中去搜索一个最小的子结构,但需要保持网络的输出性能相差不大。对于网络搜索,在之前的NAS工作中已经将可微分的概念引入到搜索过程中,如DARTS,但是网络剪枝过程中却无法直接使用,这是因为:
- 1)搜索空间不同:DARTS中的搜索空间是预先定义好的网络结构,而网络剪枝中却是网络每层的channel;
- 2)“元素”之间的关系不同:在DARTS中这些“元素”是相互独立的,而在网络剪枝过程中其却是存在依赖关系的,如要保留第 k + 1 k+1 k+1个channel那么前面的 k k k个channel应该是保留的;
正如上面提到在网络层中channel是前向依赖的,而这一点性质与马尔可夫的决策过程接近,因而文章将网络剪枝的过程抽象为一个马尔可夫决策过程,使用 S k S_k Sk表示保留 k t h k^{th} kth个channel的状态,从状态 S k S_k Sk到 S k + 1 S_{k+1} Sk+1(保留 ( k + 1 ) t h (k+1)^{th} (k+1)th)是存在状态转移的,因而对于每一个channel就可以得到一个状态转移的概率,因而就可以根据这个概率来选择特征图中的channel。在实际操作中会channel与对应的概率相乘从而去控制剪裁。在此基础上通过概率模型构建出来一个带参数的采样空间,从而这个过程便变得可微分,就使得可以使用一个目标FLOPs去约束,从而达到剪枝的目的。
将文章的方法与之前的一些方法进行比较,见下图所示:
2. 方法设计
2.1 整体Pipline
文章的整体网络方法pipline如下图所示:
上图的a部分表示的文章训练的两个stage:
- 1)stage1:固定结构参数,通过参数共享训练4个子网络(最小/最大/两个随机,通过参数共享形式在一个卷积里面选择不同数量channel实现),引入扰动寻找最优的搜索结构;
- 2)stage2:将网络设置为最大,训练网络的结构参数,通过训练约束网络结构使其在channel上“稀疏”;
上图的b部分是数据在一个Conv+BN+ReLU结构中的融合过程。
2.2 基于马尔可夫过程的剪裁
对于一个网络层 L ( i ) L^{(i)} L(i),其输出的channel维度为 C o u t ( i ) C_{out}^{(i)} Cout(i),其输出描述为下面的形式:
O k ( i ) = w k ( i ) ⊙ x , k = 1 , 2 , … , C o u t ( i ) O_k^{(i)}=w_k^{(i)}\odot x,k=1,2,\dots,C_{out}^{(i)} Ok(i)=wk(i)⊙x,k=1,2,…,Cout(i)
其中, w k ( i ) w_k^{(i)} wk(i)表示参数。由于网络层channel的性质,文章将其在channel维度上进行建模,构建一个马尔可夫过程,描述为下图所示:
当前 k k k的channel与之前的 k − 1 k-1 k−1个channel是相关的,并且之前的 k − 1 k-1 k−1个channel应该是确定存在的。对于前面的 k − 1 k-1 k−1个channel其保留的概率描述为 p ( w 1 , w 2 , … , w k − 1 ) p(w1,w2,\dots,w_{k-1}) p(w1,w2,…,wk−1),那么第 k k k个channel在该条件下被保留下来的概率为:
p ( w 1 , w 2 , … , w k ) = p ( w k ∣ w 1 , w 2 , … , w k − 1 ) p ( w 1 , w 2 , … , w k − 1 ) p(w1,w2,\dots,w_k)=p(w_k|w1,w2,\dots,w_{k-1})p(w1,w2,\dots,w_{k-1}) p(w1,w2,…,w