##Abstract
在DeepLearning日益发展的同时,模型压缩的关注度也越来越大。继BWN和TWN之后, 这篇论文在超低比特量化领域的一篇新思想文章,发表在AAAI2018上,作者是阿里的。
该文的主要思想是将超低bit quantization建模成一个离散约束优化问题(discretely constrained optimization problem)。借助ADMM(Alternating Direction Method of Mutipliers)思想,将连续参数从网络的离散约束中分离出来,并将原来的难题转化为几个子问题。该文针对这些子问题采用了extragradient以及iterative quantization算法(比传统算法更快收敛)。实验在图像分类以及目标检测(SSD)上都比最新的方法要好。
解决了什么?为什么写这篇文章?
- 之前的工作将pretrained weights量化到4-12bit效果很好,但是当只用1bit或者2bit表示weights时,只是在小datasets(MNIST、CIFAR10)上表现很好,在大datasets通常会产生很大的loss。
- 该文提出了一个独特的strategy,将超低bit weights作为一个离散约束的非凸优化问题(MIP)。借助ADMM算法,本文主要思想是使用离散空间中的辅助变量将离散约束中的连续变量解耦。不像之前量化方法修改连续参数中的特定梯度,本文同时优化连续和离散空间,使用增广拉格朗日函数连接了两个空间的解。
ADMM算法
主要参考这篇博客http://mullover.me/2016/01/19/admm-for-distributed-statistical-learning/
若优化问题可以表示为
m i n f ( x ) + g ( z ) s . t . A x + B z = c ( 1 ) min\quad f(x)+g(z) \quad s.t.\quad Ax+Bz=c\quad (1) minf(x)+g(z)s.t.Ax+Bz=c(1)
其中 x ∈ R s , z ∈ R n , A ∈ R p × n , B ∈ R p × m    a n d    c ∈ R p . x\in \mathbb R^s,z\in \mathbb R^n, A\in \mathbb R^{p \times n}, B\in \mathbb R^{p \times m} \; and \; c\in \mathbb R^p. x∈Rs,z∈Rn,A∈Rp×n,B∈Rp×mandc∈Rp. x x x与 z z z是优化变量, f ( x ) + g ( z ) f(x)+g(z) f(x)+g(z)是目标函数, A x + B z = c Ax+Bz=c Ax+Bz=c是等式约束条件。
等式 ( 1 ) (1) (1)的增广拉格朗日函数(augmented Lagrangian)可以表示为:
L ρ ( x , z , y ) = f ( x ) + g ( z ) + y T ( A x + B z − c ) + ( ρ / 2 ) ∥ A x + B z − c ∥ 2 2 ( 2 ) L_\rho(x,z,y)=f(x)+g(z)+y^T(Ax+Bz-c)+(\rho/2)\Vert Ax+Bz-c \Vert^2_2 \quad (2) Lρ(x,z,y)=f(x)+g(z)+yT(Ax+Bz−c)+(ρ/2)∥Ax+Bz−c∥22(2)
其中 y y y是拉格朗日乘子, ρ > 0 \rho>0 ρ>0是惩罚参数。增广就是由于加入了二次惩罚项。
则ADMM由三步迭代组成:
x k + 1 : = a r g min x L ρ ( x , z k , y k ) x^{k+1}:=arg\min\limits_xL_\rho(x,z^k,y^k) xk+1:=argxminLρ(x,zk,yk)
z k + 1 : = a r g min z L ρ ( x k + 1 , z , y k ) z^{k+1}:=arg\min\limits_zL_\rho(x^{k+1},z,y^k) zk+1:=argzminLρ(xk+1,z,yk)
y k + 1 : = y k + ρ ( A x k + 1 + B z k + 1 − c ) y^{k+1}:=y^k+\rho(Ax^{k+1}+Bz^{k+1}-c) yk+1:=yk+ρ(Axk+1+Bzk+1−c)
可以看出,每次迭代分为三步:
1.求解与 x x x相关的最小化问题,更新变量 x x x
2.求解与 x x x相关的最小化问题,更新变量 x x x
3.更新 ρ \rho ρ
Objective function
记作 f ( W ) f(W) f(W)为一个NN的loss fuction, W = W= W={ W 1 , W 2 , . . . , W L W_1,W_2,...,W_L W