资源约束下的结构化剪枝技术
引言
结构化剪枝是神经网络压缩领域中的一项关键技术,它通过有选择地移除整个结构单元(如通道、滤波器或层)来减小模型尺寸并提高计算效率。在现实应用中,我们往往面临着严格的资源限制,比如内存容量、计算能力、能耗或延迟要求。
结构化剪枝基础
神经网络剪枝可分为非结构化剪枝和结构化剪枝。非结构化剪枝针对单个权重进行操作,虽然理论上可以实现高压缩率,但难以在通用硬件上获得实际加速。相比之下,结构化剪枝移除整个结构单元,虽然压缩率可能较低,但能直接转化为计算加速。
在卷积神经网络中,一个典型的卷积层可表示为:
Y = W ∗ X \mathbf{Y} = \mathbf{W} * \mathbf{X} Y=W∗X
其中 X ∈ R C i n × H i n × W i n \mathbf{X} \in \mathbb{R}^{C_{in} \times H_{in} \times W_{in}} X∈RCin×Hin×Win是输入特征图, W ∈ R C o u t × C i n × K × K \mathbf{W} \in \mathbb{R}^{C_{out} \times C_{in} \times K \times K} W∈RCout×Cin×K×K是卷积核, Y ∈ R C o u t × H o u t × W o u t \mathbf{Y} \in \mathbb{R}^{C_{out} \times H_{out} \times W_{out}} Y∈RCout×Hout×Wout是输出特征图, ∗ * ∗表示卷积操作。
从计算复杂度角度看,卷积操作的FLOPs可表示为:
FLOPs ( W , X ) = 2 ⋅ C o u t ⋅ C i n ⋅ K 2 ⋅ H o u t ⋅ W o u t \text{FLOPs}(\mathbf{W}, \mathbf{X}) = 2 \cdot C_{out} \cdot C_{in} \cdot K^2 \cdot H_{out} \cdot W_{out} FLOPs(W,X)=2⋅Cout⋅Cin⋅K2⋅Hout⋅Wout
结构化剪枝后,如果保留 C o u t ′ C_{out}' Cout′个输出通道和 C i n ′ C_{in}' Cin′个输入通道,则计算复杂度降为:
FLOPs ( W ′ , X ′ ) = 2 ⋅ C o u t ′ ⋅ C i n ′ ⋅ K 2 ⋅ H o u t ⋅ W o u t \text{FLOPs}(\mathbf{W}', \mathbf{X}') = 2 \cdot C_{out}' \cdot C_{in}' \cdot K^2 \cdot H_{out} \cdot W_{out} FLOPs(W′,X′)=2⋅Cout′⋅Cin′⋅K2⋅Hout⋅Wout
从信息论角度考虑,我们可以通过特征图熵来衡量通道信息量:
H ( Y j ) = − ∑ h , w , b P ( Y j h , w , b ) log P ( Y j h , w , b ) H(\mathbf{Y}_j) = -\sum_{h,w,b} P(\mathbf{Y}_j^{h,w,b}) \log P(\mathbf{Y}_j^{h,w,b}) H(Yj)=−h,w,b∑P(Yjh,w,b)logP(Yjh,w,b)
其中 P ( Y j h , w , b ) P(\mathbf{Y}_j^{h,w,b}) P(Yjh,w,b)是特征图第 j j j通道在位置 ( h , w ) (h,w) (h,w)和批次 b b b上的归一化激活值概率分布。
资源约束建模
在实际应用中,我们通常需要在满足特定资源约束的条件下进行剪枝。这些约束可以表示为:
R ( W ′ ) ≤ R t a r g e t R(\mathbf{W}') \leq R_{target} R(W′)≤Rtarget
其中 W ′ \mathbf{W}' W′是剪枝后的模型参数, R ( ⋅ ) R(\cdot) R(⋅)是资源测量函数, R t a r g e t R_{target} Rtarget是目标资源限制。
多目标资源约束可以表示为向量形式:
R ( W ′ ) ⪯ R t a r g e t \mathbf{R}(\mathbf{W}') \preceq \mathbf{R}_{target} R(W′)⪯Rtarget
其中 R ( W ′ ) = [ R 1 ( W ′ ) , R 2 ( W ′ ) , … , R m ( W ′ ) ] T \mathbf{R}(\mathbf{W}') = [R_1(\mathbf{W}'), R_2(\mathbf{W}'), \ldots, R_m(\mathbf{W}')]^T R(W′)=[R1(W′),R2(W′),…,Rm(W′)]T是 m m m种资源测量的向量, ⪯ \preceq ⪯表示逐元素小于等于。
考虑参数量、计算复杂度和延迟等多种资源约束时,我们可以构建加权约束:
∑ i = 1 m ω i ⋅ R i ( W ′ ) R i , t a r g e t ≤ 1 \sum_{i=1}^m \omega_i \cdot \frac{R_i(\mathbf{W}')}{R_{i,target}} \leq 1 i=1∑mωi⋅Ri,targetRi(W′)≤1
其中 ω i \omega_i ωi是每种资源约束的权重系数,满足 ∑ i = 1 m ω i = 1 \sum_{i=1}^m \omega_i = 1 ∑i=1mωi=1。
资源与网络结构之间的关系可以通过函数 R ( c ) R(\mathbf{c}) R(c)建模,其中 c = [ c 1 , c 2 , … , c L ] \mathbf{c} = [c_1, c_2, \ldots, c_L] c=[c1,c2,…,cL]表示各层的通道数。对于计算复杂度,可以表示为:
R F L O P s ( c ) = ∑ l = 1 L − 1 2 ⋅ c l ⋅ c l + 1 ⋅ K l 2 ⋅ H l ⋅ W l R_{FLOPs}(\mathbf{c}) = \sum_{l=1}^{L-1} 2 \cdot c_l \cdot c_{l+1} \cdot K_l^2 \cdot H_l \cdot W_l RFLOPs(c)=l=1∑L−12⋅cl⋅cl+1⋅Kl2⋅Hl⋅Wl
结构化剪枝的数学公式化
剪枝优化目标
结构化剪枝可以形式化为一个约束优化问题:
min W ′ L ( W ′ , D ) s.t. R ( W ′ ) ≤ R t a r g e t \min_{\mathbf{W}'} \mathcal{L}(\mathbf{W}', \mathcal{D}) \quad \text{s.t.} \quad R(\mathbf{W}') \leq R_{target} W′minL(W′,D)s.t.R(W′)≤Rtarget
这个问题可以转化为拉格朗日形式:
L t o t a l ( W ′ ) = L ( W ′ , D ) + λ ⋅ max ( 0 , R ( W ′ ) − R t a r g e t ) \mathcal{L}_{total}(\mathbf{W}') = \mathcal{L}(\mathbf{W}', \mathcal{D}) + \lambda \cdot \max(0, R(\mathbf{W}') - R_{target}) Ltotal(W′)=L(W′,D)+λ⋅max(0,R(W′)−Rtarget)
考虑二值掩码变量,我们可以进一步将问题转化为:
min M , W ′ L ( M ⊙ W ′ , D ) s.t. ∥ M ∥ 0 ≤ k \min_{\mathbf{M}, \mathbf{W}'} \mathcal{L}(\mathbf{M} \odot \mathbf{W}', \mathcal{D}) \quad \text{s.t.} \quad \|\mathbf{M}\|_0 \leq k M,W′minL(M⊙W′,D)s.t.∥M∥0≤k
其中 M \mathbf{M} M是结构掩码, ⊙ \odot ⊙表示逐元素乘法, k k k是保留的结构单元数量。
由于 L 0 L_0 L0范数优化是NP-hard问题,我们可以通过连续松弛将其近似为:
min M , W ′ L ( M ⊙ W ′ , D ) + λ ⋅ ∥ M ∥ p p \min_{\mathbf{M}, \mathbf{W}'} \mathcal{L}(\mathbf{M} \odot \mathbf{W}', \mathcal{D}) + \lambda \cdot \|\mathbf{M}\|_p^p M,W′minL(M⊙W′,D)+λ⋅∥M∥pp
其中 p p p可以是0.5、1或2,分别对应不同程度的稀疏性近似。当 p → 0 p \to 0 p→0时, ∥ M ∥ p p \|\mathbf{M}\|_p^p ∥M∥pp越接近 ∥ M ∥ 0 \|\mathbf{M}\|_0 ∥M∥0。
从贝叶斯角度看,我们可以将结构掩码 M \mathbf{M} M视为随机变量,并引入变分推断框架:
min ϕ E q ϕ ( M ) [ L ( M ⊙ W ′ , D ) ] + λ ⋅ KL ( q ϕ ( M ) ∥ p ( M ) ) \min_{\phi} \mathbb{E}_{q_{\phi}(\mathbf{M})}[\mathcal{L}(\mathbf{M} \odot \mathbf{W}', \mathcal{D})] + \lambda \cdot \text{KL}(q_{\phi}(\mathbf{M}) \| p(\mathbf{M})) ϕminEqϕ(M)[L(M⊙W′,D)]+λ⋅KL(qϕ(M)∥p(M))
其中 q ϕ ( M ) q_{\phi}(\mathbf{M}) qϕ(M)是掩码的变分后验分布, p ( M ) p(\mathbf{M}) p(M)是先验分布。
通道剪枝数学表达
以通道剪枝为例,考虑第 l l l层的输出通道剪枝,可以引入通道重要性度量 s j l s_j^l sjl:
s j l = I ( W j l , X l , Y l ) s_j^l = I(\mathbf{W}_j^l, \mathbf{X}^l, \mathbf{Y}^l) sjl=I(Wjl,Xl,Yl)
其中 I ( ⋅ ) I(\cdot) I(⋅)是重要性评估函数, W j l \mathbf{W}_j^l Wjl是第 l l l层第 j j j个输出通道的参数。
基于此,我们可以定义通道掩码:
m j l = { 1 , if s j l > τ l 0 , otherwise m_j^l = \begin{cases} 1, & \text{if } s_j^l > \tau^l \\ 0, & \text{otherwise} \end{cases} mjl={1,0,if sjl>τlotherwise
其中 τ l \tau^l τl是第 l l l层的剪枝阈值。为确定最优阈值,我们可以将其表述为二分搜索问题:
τ l = arg min τ ∣ ∑ j = 1 C l I ( s j l > τ ) − ( 1 − r l ) ⋅ C l ∣ \tau^l = \argmin_{\tau} \left| \sum_{j=1}^{C_l} \mathbb{I}(s_j^l > \tau) - (1-r^l) \cdot C_l \right| τl=τargmin j=1∑ClI(sjl>τ)−(1−rl)⋅Cl
其中 r l r^l rl是第 l l l层的目标剪枝率, I ( ⋅ ) \mathbb{I}(\cdot) I(⋅)是指示函数。
为了处理二元掩码的不可微性,我们可以引入软掩码近似:
m ~ j l = σ ( α ⋅ ( s j l − τ l ) ) \tilde{m}_j^l = \sigma\left(\alpha \cdot (s_j^l - \tau^l)\right) m~jl=σ(α⋅(sjl−τl))
其中 σ \sigma σ是sigmoid函数, α \alpha α是控制近似陡度的参数。当 α → ∞ \alpha \to \infty α→∞时, m ~ j l \tilde{m}_j^l m~jl趋近于 m j l m_j^l mjl。
常见的结构重要性评估方法
基于范数的方法
最简单直观的方法是使用权重范数来评估结构重要性:
s j l = ∥ W j l ∥ p s_j^l = \|\mathbf{W}_j^l\|_p sjl=∥Wjl∥p
对于第 l l l层的第 j j j个卷积滤波器,其 L p L_p Lp范数重要性可表示为:
s j l = ( ∑ c = 1 C i n ∑ h = 1 K ∑ w = 1 K ∣ W j , c , h , w l ∣ p ) 1 / p s_j^l = \left( \sum_{c=1}^{C_{in}} \sum_{h=1}^K \sum_{w=1}^K |W_{j,c,h,w}^l|^p \right)^{1/p} sjl=(c=1∑Cinh=1∑Kw=1∑K∣Wj,c,h,wl∣p)1/p
当采用混合范数时,可以更精细地捕获结构特性:
s j l = ∑ c = 1 C i n ∥ W j , c , : , : l ∥ F s_j^l = \sum_{c=1}^{C_{in}} \|\mathbf{W}_{j,c,:,:}^l\|_F sjl=c=1∑Cin∥Wj,c,:,:l∥F
其中 ∥ W j , c , : , : l ∥ F \|\mathbf{W}_{j,c,:,:}^l\|_F ∥Wj,c,:,:l∥F是第 j j j个滤波器与第 c c c个输入通道连接的卷积核的Frobenius范数。
基于特征图的方法
特征图的统计信息也可用于评估通道重要性:
s j l = 1 N ∑ i = 1 N ∥ Y j , i l ∥ 1 s_j^l = \frac{1}{N} \sum_{i=1}^N \|\mathbf{Y}_{j,i}^l\|_1 sjl=N1i=1∑N∥Yj,il∥1
引入信息熵测度,可以评估特征图通道的信息量:
s j l = H ( Y j l ) = − ∑ h = 1 H o u t ∑ w = 1 W o u t ∑ i = 1 N P ( Y j , i , h , w l ) log P ( Y j , i , h , w l ) s_j^l = H(\mathbf{Y}_j^l) = -\sum_{h=1}^{H_{out}} \sum_{w=1}^{W_{out}} \sum_{i=1}^N P(\mathbf{Y}_{j,i,h,w}^l) \log P(\mathbf{Y}_{j,i,h,w}^l) sjl=H(Yjl)=−h=1∑Houtw=1∑Wouti=1∑NP(Yj,i,h,wl)logP(Yj,i,h,wl)
其中 P ( Y j , i , h , w l ) P(\mathbf{Y}_{j,i,h,w}^l) P(Yj,i,h,wl)是归一化后的激活值概率。
考虑特征图的空间相关性,我们可以引入空间注意力重要性度量:
s j l = 1 N ∑ i = 1 N ∥ ∑ h , w ∣ Y j , i , h , w l ∣ ⋅ ( h , w ) ∑ h , w ∣ Y j , i , h , w l ∣ − μ j l ∥ 2 2 s_j^l = \frac{1}{N} \sum_{i=1}^N \left\| \frac{\sum_{h,w} |\mathbf{Y}_{j,i,h,w}^l| \cdot (h,w)}{\sum_{h,w} |\mathbf{Y}_{j,i,h,w}^l|} - \mu_j^l \right\|_2^2 sjl=N1i=1∑N ∑h,w∣Yj,i,h,wl∣∑h,w∣Yj,i,h,wl∣⋅(h,w)−μjl 22
其中 μ j l \mu_j^l μjl是空间注意力的均值向量。
基于梯度的方法
考虑特征图对损失函数的影响,可以使用梯度信息:
s j l = ∥ ∂ L ∂ Y j l ∥ F 2 s_j^l = \left\|\frac{\partial \mathcal{L}}{\partial \mathbf{Y}_j^l}\right\|_F^2 sjl= ∂Yjl∂L F2
更深入地,我们可以考虑权重与梯度的乘积:
s j l = ∣ W j l ⋅ ∂ L ∂ W j l ∣ = ∣ ∑ c = 1 C i n ∑ h = 1 K ∑ w = 1 K W j , c , h , w l ⋅ ∂ L ∂ W j , c , h , w l ∣ s_j^l = \left| \mathbf{W}_j^l \cdot \frac{\partial \mathcal{L}}{\partial \mathbf{W}_j^l} \right| = \left| \sum_{c=1}^{C_{in}} \sum_{h=1}^K \sum_{w=1}^K W_{j,c,h,w}^l \cdot \frac{\partial \mathcal{L}}{\partial W_{j,c,h,w}^l} \right| sjl= Wjl⋅∂Wjl∂L = c=1∑Cinh=1∑Kw=1∑KWj,c,h,wl⋅∂Wj,c,h,wl∂L
引入二阶信息,可以使用Fisher信息矩阵:
s j l = W j l ⋅ F j l ⋅ W j l s_j^l = \mathbf{W}_j^l \cdot \mathbf{F}_j^l \cdot \mathbf{W}_j^l sjl=Wjl⋅Fjl⋅Wjl
其中 F j l = E [ ( ∂ L ∂ W j l ) ( ∂ L ∂ W j l ) T ] \mathbf{F}_j^l = \mathbb{E}\left[ \left( \frac{\partial \mathcal{L}}{\partial \mathbf{W}_j^l} \right) \left( \frac{\partial \mathcal{L}}{\partial \mathbf{W}_j^l} \right)^T \right] Fjl=E[(∂Wjl∂L)(∂Wjl∂L)T]是Fisher信息矩阵。
基于Taylor展开的方法
更复杂的方法利用损失函数对参数的Taylor展开近似:
Δ L ( W j l = 0 ) ≈ L ( W ) − L ( W , W j l = 0 ) \Delta \mathcal{L}(\mathbf{W}_j^l=0) \approx \mathcal{L}(\mathbf{W}) - \mathcal{L}(\mathbf{W}, \mathbf{W}_j^l=0) ΔL(Wjl=0)≈L(W)−L(W,Wjl=0)
一阶泰勒展开:
Δ L ( W j l = 0 ) ≈ W j l ⋅ ∂ L ∂ W j l \Delta \mathcal{L}(\mathbf{W}_j^l=0) \approx \mathbf{W}_j^l \cdot \frac{\partial \mathcal{L}}{\partial \mathbf{W}_j^l} ΔL(Wjl=0)≈Wjl⋅∂Wjl∂L
二阶泰勒展开:
Δ L ( W j l = 0 ) ≈ W j l ⋅ ∂ L ∂ W j l + 1 2 W j l ⋅ H j l ⋅ W j l \Delta \mathcal{L}(\mathbf{W}_j^l=0) \approx \mathbf{W}_j^l \cdot \frac{\partial \mathcal{L}}{\partial \mathbf{W}_j^l} + \frac{1}{2} \mathbf{W}_j^l \cdot \mathbf{H}_j^l \cdot \mathbf{W}_j^l ΔL(Wjl=0)≈Wjl⋅∂Wjl∂L+21Wjl⋅Hjl⋅Wjl
其中 H j l = ∂ 2 L ∂ ( W j l ) 2 \mathbf{H}_j^l = \frac{\partial^2 \mathcal{L}}{\partial (\mathbf{W}_j^l)^2} Hjl=∂(Wjl)2∂2L是Hessian矩阵。
由于计算完整Hessian矩阵成本高昂,我们可以使用Hutchinson方法近似计算矩阵的迹:
tr ( H j l ) ≈ E v ∼ N ( 0 , I ) [ v T H j l v ] \text{tr}(\mathbf{H}_j^l) \approx \mathbb{E}_{\mathbf{v} \sim \mathcal{N}(0, \mathbf{I})}[\mathbf{v}^T \mathbf{H}_j^l \mathbf{v}] tr(Hjl)≈Ev∼N(0,I)[vTHjlv]
其中 v \mathbf{v} v是从标准正态分布采样的随机向量。
资源约束下的全局优化
在资源约束下,简单地对每层独立剪枝可能不是最优策略。全局优化方法考虑各层之间的相互影响:
min { M l } l = 1 L ∑ l = 1 L ∑ j = 1 C l s j l ⋅ ( 1 − m j l ) s.t. R ( { M l } l = 1 L ) ≤ R t a r g e t \min_{\{\mathbf{M}^l\}_{l=1}^L} \sum_{l=1}^L \sum_{j=1}^{C_l} s_j^l \cdot (1-m_j^l) \quad \text{s.t.} \quad R(\{\mathbf{M}^l\}_{l=1}^L) \leq R_{target} {Ml}l=1Lminl=1∑Lj=1∑Clsjl⋅(1−mjl)s.t.R({Ml}l=1L)≤Rtarget
这可以通过拉格朗日乘子法转化为:
min { M l } l = 1 L ∑ l = 1 L ∑ j = 1 C l s j l ⋅ ( 1 − m j l ) + λ ⋅ ( R ( { M l } l = 1 L ) − R t a r g e t ) \min_{\{\mathbf{M}^l\}_{l=1}^L} \sum_{l=1}^L \sum_{j=1}^{C_l} s_j^l \cdot (1-m_j^l) + \lambda \cdot (R(\{\mathbf{M}^l\}_{l=1}^L) - R_{target}) {Ml}l=1Lminl=1∑Lj=1∑Clsjl⋅(1−mjl)+λ⋅(R({Ml}l=1L)−Rtarget)
其中 λ \lambda λ是拉格朗日乘子,可通过迭代方法求解。
对于多层感知机网络,资源约束(如FLOPs)可以表示为:
R ( { M l } l = 1 L ) = ∑ l = 1 L − 1 ∑ i = 1 n l ∑ j = 1 n l + 1 m i l ⋅ m j l + 1 R(\{\mathbf{M}^l\}_{l=1}^L) = \sum_{l=1}^{L-1} \sum_{i=1}^{n_l} \sum_{j=1}^{n_{l+1}} m_i^l \cdot m_j^{l+1} R({Ml}l=1L)=l=1∑L−1i=1∑nlj=1∑nl+1mil⋅mjl+1
其中 n l n_l nl是第 l l l层的神经元数量。
网络敏感性分析可以帮助我们判断各层对模型性能的影响。如果移除第 l l l层的第 j j j个通道,损失函数的期望增量可以表示为:
E [ Δ L j l ] = E [ L ( W , M ⊙ 1 − e j l ) − L ( W , M ) ] \mathbb{E}[\Delta \mathcal{L}_{j}^l] = \mathbb{E}[\mathcal{L}(\mathbf{W}, \mathbf{M} \odot \mathbf{1} - \mathbf{e}_j^l) - \mathcal{L}(\mathbf{W}, \mathbf{M})] E[ΔLjl]=E[L(W,M⊙1−ejl)−L(W,M)]
其中 e j l \mathbf{e}_j^l ejl是第 l l l层第 j j j个通道的单位向量。我们可以通过敏感性为权重的优化目标:
min { M l } l = 1 L ∑ l = 1 L ∑ j = 1 C l E [ Δ L j l ] ⋅ ( 1 − m j l ) s.t. R ( { M l } l = 1 L ) ≤ R t a r g e t \min_{\{\mathbf{M}^l\}_{l=1}^L} \sum_{l=1}^L \sum_{j=1}^{C_l} \mathbb{E}[\Delta \mathcal{L}_{j}^l] \cdot (1-m_j^l) \quad \text{s.t.} \quad R(\{\mathbf{M}^l\}_{l=1}^L) \leq R_{target} {Ml}l=1Lminl=1∑Lj=1∑ClE[ΔLjl]⋅(1−mjl)s.t.R({Ml}l=1L)≤Rtarget
在实际操作中,可以通过动态规划解决约束优化问题。定义 D P [ l ] [ r ] DP[l][r] DP[l][r]为前 l l l层在资源约束 r r r下的最小损失,状态转移方程为:
D P [ l ] [ r ] = min r l ≤ r { D P [ l − 1 ] [ r − r l ] + L o s s ( l , r l ) } DP[l][r] = \min_{r_l \leq r} \{DP[l-1][r-r_l] + Loss(l, r_l)\} DP[l][r]=rl≤rmin{DP[l−1][r−rl]+Loss(l,rl)}
其中 L o s s ( l , r l ) Loss(l, r_l) Loss(l,rl)是第 l l l层在资源约束 r l r_l rl下的最优剪枝损失。
ADMM辅助优化
交替方向乘子法(ADMM)可以有效解决结构化剪枝中的约束优化问题。首先将原问题重新表述为:
min W , Z L ( W , D ) s.t. W = Z , Z ∈ C \min_{\mathbf{W}, \mathbf{Z}} \mathcal{L}(\mathbf{W}, \mathcal{D}) \quad \text{s.t.} \quad \mathbf{W} = \mathbf{Z}, \mathbf{Z} \in \mathcal{C} W,ZminL(W,D)s.t.W=Z,Z∈C
其中 C \mathcal{C} C是满足结构约束的参数空间。增广拉格朗日函数为:
L ρ ( W , Z , U ) = L ( W , D ) + ρ 2 ∥ W − Z + U ∥ F 2 − ρ 2 ∥ U ∥ F 2 \mathcal{L}_{\rho}(\mathbf{W}, \mathbf{Z}, \mathbf{U}) = \mathcal{L}(\mathbf{W}, \mathcal{D}) + \frac{\rho}{2}\|\mathbf{W} - \mathbf{Z} + \mathbf{U}\|_F^2 - \frac{\rho}{2}\|\mathbf{U}\|_F^2 Lρ(W,Z,U)=L(W,D)+2ρ∥W−Z+U∥F2−2ρ∥U∥F2
ADMM将问题分解为交替优化步骤:
-
更新 W \mathbf{W} W(网络训练):
W k + 1 = arg min W L ( W , D ) + ρ 2 ∥ W − Z k + U k ∥ F 2 \mathbf{W}^{k+1} = \arg\min_{\mathbf{W}} \mathcal{L}(\mathbf{W}, \mathcal{D}) + \frac{\rho}{2}\|\mathbf{W} - \mathbf{Z}^k + \mathbf{U}^k\|_F^2 Wk+1=argWminL(W,D)+2ρ∥W−Zk+Uk∥F2 -
更新 Z \mathbf{Z} Z(结构投影):
Z k + 1 = Π C ( W k + 1 + U k ) \mathbf{Z}^{k+1} = \Pi_{\mathcal{C}}(\mathbf{W}^{k+1} + \mathbf{U}^k) Zk+1=ΠC(Wk+1+Uk) -
更新拉格朗日乘子:
U k + 1 = U k + W k + 1 − Z k + 1 \mathbf{U}^{k+1} = \mathbf{U}^k + \mathbf{W}^{k+1} - \mathbf{Z}^{k+1} Uk+1=Uk+Wk+1−Zk+1
其中 Π C \Pi_{\mathcal{C}} ΠC是将参数投影到结构约束空间的操作, ρ \rho ρ是惩罚参数。
对于组稀疏约束,投影操作可以表示为:
Π C ( V ) = arg min Z ∥ Z − V ∥ F 2 s.t. ∑ j = 1 C l I ( ∥ Z j l ∥ F > 0 ) ≤ ( 1 − r l ) ⋅ C l , ∀ l \Pi_{\mathcal{C}}(\mathbf{V}) = \arg\min_{\mathbf{Z}} \|\mathbf{Z} - \mathbf{V}\|_F^2 \quad \text{s.t.} \quad \sum_{j=1}^{C_l} \mathbb{I}(\|\mathbf{Z}_j^l\|_F > 0) \leq (1-r^l) \cdot C_l, \forall l ΠC(V)=argZmin∥Z−V∥F2s.t.j=1∑ClI(∥Zjl∥F>0)≤(1−rl)⋅Cl,∀l
这个投影问题可以通过贪心算法求解:对于每层,保留 ∥ V j l ∥ F \|\mathbf{V}_j^l\|_F ∥Vjl∥F最大的 ( 1 − r l ) ⋅ C l (1-r^l) \cdot C_l (1−rl)⋅Cl个通道。
ADMM的收敛性可通过定理保证:对于凸目标函数,ADMM在适当条件下以 O ( 1 / k ) O(1/k) O(1/k)的速率收敛到全局最优解;对于非凸问题,ADMM收敛到局部最优解。收敛条件可以表述为:
∥ W k + 1 − W k ∥ F 2 + ∥ Z k + 1 − Z k ∥ F 2 < ϵ \|\mathbf{W}^{k+1} - \mathbf{W}^{k}\|_F^2 + \|\mathbf{Z}^{k+1} - \mathbf{Z}^{k}\|_F^2 < \epsilon ∥Wk+1−Wk∥F2+∥Zk+1−Zk∥F2<ϵ
资源感知的自动通道剪枝
为了更精确地满足资源约束,可以引入资源感知的自动剪枝框架:
min W , α L ( W , α , D ) + λ ⋅ ∣ R ( α ) − R t a r g e t ∣ \min_{\mathbf{W}, \alpha} \mathcal{L}(\mathbf{W}, \alpha, \mathcal{D}) + \lambda \cdot |R(\alpha) - R_{target}| W,αminL(W,α,D)+λ⋅∣R(α)−Rtarget∣
其中 α \alpha α是可学习的通道重要性参数, R ( α ) R(\alpha) R(α)是基于当前通道配置的资源估计。
具体而言,可以使用Gumbel-Softmax技巧将离散的通道选择转化为可微的形式:
α ^ j l = exp ( ( log α j l + g j ) / τ ) ∑ j ′ exp ( ( log α j ′ l + g j ′ ) / τ ) \hat{\alpha}_j^l = \frac{\exp((\log \alpha_j^l + g_j) / \tau)}{\sum_{j'} \exp((\log \alpha_{j'}^l + g_{j'}) / \tau)} α^jl=∑j′exp((logαj′l+gj′)/τ)exp((logαjl+gj)/τ)
其中 g j g_j gj是从Gumbel分布中采样的噪声, τ \tau τ是温度参数,随着训练逐渐降低。
通道配置的软指示函数可以表示为:
m j l ( α ) = σ ( γ ⋅ ( α j l − β l ) ) m_j^l(\alpha) = \sigma\left(\gamma \cdot (\alpha_j^l - \beta_l)\right) mjl(α)=σ(γ⋅(αjl−βl))
其中 σ \sigma σ是sigmoid函数, γ \gamma γ是缩放因子, β l \beta_l βl是第 l l l层的剪枝阈值。
训练过程中,梯度可以通过直通估计器(Straight-Through Estimator, STE)传播:
∂ L ∂ α j l = ∂ L ∂ m j l ⋅ ∂ m j l ∂ α j l ≈ ∂ L ∂ m j l ⋅ I ( ∣ α j l − β l ∣ < ϵ ) ⋅ γ ⋅ σ ′ ( γ ⋅ ( α j l − β l ) ) \frac{\partial \mathcal{L}}{\partial \alpha_j^l} = \frac{\partial \mathcal{L}}{\partial m_j^l} \cdot \frac{\partial m_j^l}{\partial \alpha_j^l} \approx \frac{\partial \mathcal{L}}{\partial m_j^l} \cdot \mathbb{I}(|\alpha_j^l - \beta_l| < \epsilon) \cdot \gamma \cdot \sigma'(\gamma \cdot (\alpha_j^l - \beta_l)) ∂αjl∂L=∂mjl∂L⋅∂αjl∂mjl≈∂mjl∂L⋅I(∣αjl−βl∣<ϵ)⋅γ⋅σ′(γ⋅(αjl−βl))
资源约束项可以使用可微分资源估计函数:
R ( α ) = ∑ l = 1 L ∑ j = 1 C l m j l ( α ) ⋅ r j l R(\alpha) = \sum_{l=1}^L \sum_{j=1}^{C_l} m_j^l(\alpha) \cdot r_j^l R(α)=l=1∑Lj=1∑Clmjl(α)⋅rjl
其中 r j l r_j^l rjl是第 l l l层第 j j j个通道的资源占用(如FLOPs、参数量等)。
为处理多种资源约束,可以使用加权约束项:
L r e s ( α ) = ∑ i = 1 m λ i ⋅ ∣ R i ( α ) R i , t a r g e t − 1 ∣ \mathcal{L}_{res}(\alpha) = \sum_{i=1}^m \lambda_i \cdot \left| \frac{R_i(\alpha)}{R_{i,target}} - 1 \right| Lres(α)=i=1∑mλi⋅ Ri,targetRi(α)−1
知识蒸馏辅助剪枝
知识蒸馏可以辅助结构化剪枝,减轻精度损失:
L K D = ( 1 − β ) ⋅ L C E ( W ′ , D ) + β ⋅ L d i s t i l l ( W ′ , W , D ) \mathcal{L}_{KD} = (1-\beta) \cdot \mathcal{L}_{CE}(\mathbf{W}', \mathcal{D}) + \beta \cdot \mathcal{L}_{distill}(\mathbf{W}', \mathbf{W}, \mathcal{D}) LKD=(1−β)⋅LCE(W′,D)+β⋅Ldistill(W′,W,D)
其中 L C E \mathcal{L}_{CE} LCE是标准交叉熵损失, L d i s t i l l \mathcal{L}_{distill} Ldistill是蒸馏损失,基本形式可以表示为:
L d i s t i l l = τ 2 ⋅ KL ( σ ( z τ ) , σ ( z ′ τ ) ) \mathcal{L}_{distill} = \tau^2 \cdot \text{KL}\left(\sigma\left(\frac{\mathbf{z}}{\tau}\right), \sigma\left(\frac{\mathbf{z}'}{\tau}\right)\right) Ldistill=τ2⋅KL(σ(τz),σ(τz′))
这里 z \mathbf{z} z和 z ′ \mathbf{z}' z′分别是原始模型和剪枝模型的logits, σ \sigma σ是softmax函数, τ \tau τ是温度参数, KL \text{KL} KL是KL散度。
更一般地,KL散度可以表示为:
KL ( P ∥ Q ) = ∑ i P ( i ) log P ( i ) Q ( i ) = ∑ i P ( i ) log P ( i ) − ∑ i P ( i ) log Q ( i ) = − H ( P ) + H ( P , Q ) \text{KL}(P \| Q) = \sum_{i} P(i) \log \frac{P(i)}{Q(i)} = \sum_{i} P(i) \log P(i) - \sum_{i} P(i) \log Q(i) = -H(P) + H(P, Q) KL(P∥Q)=i∑P(i)logQ(i)P(i)=i∑P(i)logP(i)−i∑P(i)logQ(i)=−H(P)+H(P,Q)
其中 H ( P ) H(P) H(P)是分布 P P P的熵, H ( P , Q ) H(P, Q) H(P,Q)是 P P P和 Q Q Q的交叉熵。
在特征蒸馏中,我们可以使用特征图之间的匹配损失:
L f e a t = ∑ l ∈ S ∥ F l ∥ F l ∥ 2 − F l ′ ∥ F l ′ ∥ 2 ∥ 2 2 \mathcal{L}_{feat} = \sum_{l \in \mathcal{S}} \left\| \frac{\mathbf{F}_l}{\|\mathbf{F}_l\|_2} - \frac{\mathbf{F}_l'}{\|\mathbf{F}_l'\|_2} \right\|_2^2 Lfeat=l∈S∑ ∥Fl∥2Fl−∥Fl′∥2Fl′ 22
其中 F l \mathbf{F}_l Fl和 F l ′ \mathbf{F}_l' Fl′分别是原始模型和剪枝模型在第 l l l层的特征图, S \mathcal{S} S是选择用于蒸馏的层集合。
更一般地,我们可以使用Wasserstein距离度量特征分布差异:
L w a s s = ∑ l ∈ S W 2 ( F l , F l ′ ) \mathcal{L}_{wass} = \sum_{l \in \mathcal{S}} W_2(\mathbf{F}_l, \mathbf{F}_l') Lwass=l∈S∑W2(Fl,Fl′)
其中 W 2 W_2 W2是2阶Wasserstein距离,定义为:
W 2 ( μ , ν ) = inf γ ∈ Γ ( μ , ν ) ( ∫ ∥ x − y ∥ 2 2 d γ ( x , y ) ) 1 / 2 W_2(\mu, \nu) = \inf_{\gamma \in \Gamma(\mu, \nu)} \left( \int \|x-y\|_2^2 d\gamma(x,y) \right)^{1/2} W2(μ,ν)=γ∈Γ(μ,ν)inf(∫∥x−y∥22dγ(x,y))1/2
Γ ( μ , ν ) \Gamma(\mu, \nu) Γ(μ,ν)是所有边缘分布分别为 μ \mu μ和 ν \nu ν的联合分布集合。
结构化剪枝的完整算法流程
资源约束下的结构化剪枝的典型算法流程如下:
-
对预训练模型进行重要性评估: s j l = I ( W j l , X l , Y l ) s_j^l = I(\mathbf{W}_j^l, \mathbf{X}^l, \mathbf{Y}^l) sjl=I(Wjl,Xl,Yl)
-
根据资源约束 R t a r g e t R_{target} Rtarget,求解全局优化问题,确定每层的剪枝比例:
{ M l } l = 1 L = arg min { M l } ∑ l = 1 L ∑ j = 1 C l s j l ⋅ ( 1 − m j l ) s.t. R ( { M l } ) ≤ R t a r g e t \{\mathbf{M}^l\}_{l=1}^L = \arg\min_{\{\mathbf{M}^l\}} \sum_{l=1}^L \sum_{j=1}^{C_l} s_j^l \cdot (1-m_j^l) \quad \text{s.t.} \quad R(\{\mathbf{M}^l\}) \leq R_{target} {Ml}l=1L=arg{Ml}minl=1∑Lj=1∑Clsjl⋅(1−mjl)s.t.R({Ml})≤Rtarget -
应用结构掩码生成剪枝后的模型: W ′ = { M l ⊙ W l } l = 1 L \mathbf{W}' = \{\mathbf{M}^l \odot \mathbf{W}^l\}_{l=1}^L W′={Ml⊙Wl}l=1L
-
微调剪枝后的模型,可结合知识蒸馏:
min W ′ ( 1 − β ) ⋅ L C E ( W ′ , D ) + β ⋅ L d i s t i l l ( W ′ , W , D ) \min_{\mathbf{W}'} (1-\beta) \cdot \mathcal{L}_{CE}(\mathbf{W}', \mathcal{D}) + \beta \cdot \mathcal{L}_{distill}(\mathbf{W}', \mathbf{W}, \mathcal{D}) W′min(1−β)⋅LCE(W′,D)+β⋅Ldistill(W′,W,D)
剪枝后网络的理论加速比可以计算为:
Speedup = ∑ l = 1 L 2 ⋅ C o u t l ⋅ C i n l ⋅ K l 2 ⋅ H o u t l ⋅ W o u t l ∑ l = 1 L 2 ⋅ C o u t l ′ ⋅ C i n l ′ ⋅ K l 2 ⋅ H o u t l ⋅ W o u t l \text{Speedup} = \frac{\sum_{l=1}^L 2 \cdot C_{out}^l \cdot C_{in}^l \cdot K_l^2 \cdot H_{out}^l \cdot W_{out}^l}{\sum_{l=1}^L 2 \cdot C_{out}^{l'} \cdot C_{in}^{l'} \cdot K_l^2 \cdot H_{out}^l \cdot W_{out}^l} Speedup=∑l=1L2⋅Coutl′⋅Cinl′⋅Kl2⋅Houtl⋅Woutl∑l=1L2⋅Coutl⋅Cinl⋅Kl2⋅Houtl⋅Woutl
其中 C o u t l ′ C_{out}^{l'} Coutl′和 C i n l ′ C_{in}^{l'} Cinl′分别是剪枝后第 l l l层的输出和输入通道数。
剪枝算法的收敛性分析可以通过Lyapunov函数进行:定义函数 V ( W , M ) = L ( W , M , D ) + λ ⋅ R ( M ) V(\mathbf{W}, \mathbf{M}) = \mathcal{L}(\mathbf{W}, \mathbf{M}, \mathcal{D}) + \lambda \cdot R(\mathbf{M}) V(W,M)=L(W,M,D)+λ⋅R(M),如果能证明 V ( W k + 1 , M k + 1 ) ≤ V ( W k , M k ) V(\mathbf{W}^{k+1}, \mathbf{M}^{k+1}) \leq V(\mathbf{W}^{k}, \mathbf{M}^{k}) V(Wk+1,Mk+1)≤V(Wk,Mk)且 V V V有下界,则算法保证收敛。