As the saying goes, a craftsman must first sharpen his tools: to write good MATLAB code, start with the most basic code.
1 ADMM Principles
The material below is organized from Professor Boyd's course slides and papers at Stanford University.
The basic form of an ADMM problem: an optimization problem with two separable blocks of variables and a linear equality constraint:
\begin{align*} \min_{x,z}\quad&f(\mathbf{x})+g(\mathbf{z})\\ \text{s.t.}\quad&\mathbf{Ax}+\mathbf{Bz} = \mathbf{c} \end{align*}
The corresponding augmented Lagrangian of this problem is:
L_{\rho}(\mathbf{x}, \mathbf{z}, \mathbf{y}) = f(\mathbf{x})+g(\mathbf{z})+\mathbf{y}^T(\mathbf{Ax}+\mathbf{Bz} - \mathbf{c})+\frac{\rho}{2}\Vert\mathbf{Ax}+\mathbf{Bz} - \mathbf{c}\Vert^2_2
The iterates are then updated in Gauss-Seidel fashion, as follows:
\begin{align*} \mathbf{x}^{k+1} &= \mathop{argmin}\limits_{\mathbf{x}}\; L_{\rho}(\mathbf{x}, \mathbf{z}^k, \mathbf{y}^k)\\ \mathbf{z}^{k+1} &= \mathop{argmin}\limits_{\mathbf{z}}\; L_{\rho}(\mathbf{x}^{k+1}, \mathbf{z}, \mathbf{y}^k)\\ \mathbf{y}^{k+1} &= \mathbf{y}^k+\rho(\mathbf{Ax}^{k+1}+\mathbf{Bz}^{k+1}-\mathbf{c}) \end{align*}
Further, if we define the scaled dual variable $\mathbf{u}^k = (1/\rho)\mathbf{y}^k$ in the Lagrangian, it becomes:
L_{\rho}(\mathbf{x}, \mathbf{z}, \mathbf{u}) = f(\mathbf{x})+g(\mathbf{z})+\frac{\rho}{2}\Vert\mathbf{Ax}+\mathbf{Bz} - \mathbf{c}+\mathbf{u}\Vert^2_2+\text{const}
The corresponding iterations become:
\begin{align*} \mathbf{x}^{k+1} &= \mathop{argmin}\limits_{\mathbf{x}}\; f(\mathbf{x})+ \frac{\rho}{2}\Vert\mathbf{Ax}+\mathbf{Bz}^k - \mathbf{c}+\mathbf{u}^k\Vert^2_2\\ \mathbf{z}^{k+1} &= \mathop{argmin}\limits_{\mathbf{z}}\; g(\mathbf{z})+ \frac{\rho}{2}\Vert\mathbf{Ax}^{k+1}+\mathbf{Bz} - \mathbf{c}+\mathbf{u}^k\Vert^2_2\\ \mathbf{u}^{k+1} &= \mathbf{u}^k+\mathbf{Ax}^{k+1}+\mathbf{Bz}^{k+1}-\mathbf{c} \end{align*}
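As a quick sanity check of these scaled iterations (my own illustration, not from Boyd's materials), the following Python sketch applies them to the toy problem $f(x)=\frac{1}{2}\Vert x-a\Vert^2$, $g(z)=\frac{1}{2}\Vert z-b\Vert^2$ with constraint $x-z=0$ (so $A=I$, $B=-I$, $c=0$), whose minimizer is $x=z=(a+b)/2$; both subproblems have closed forms obtained by setting gradients to zero:

```python
import numpy as np

# Toy problem: min 1/2||x-a||^2 + 1/2||z-b||^2  s.t.  x - z = 0
# (A = I, B = -I, c = 0); the minimizer is x = z = (a+b)/2.
a = np.array([1.0, 3.0])
b = np.array([5.0, -1.0])
rho = 1.0
x = z = u = np.zeros(2)
for _ in range(200):
    x = (a + rho * (z - u)) / (1 + rho)  # x-update: closed-form prox of f
    z = (b + rho * (x + u)) / (1 + rho)  # z-update: closed-form prox of g
    u = u + x - z                        # scaled dual update
print(x, z)  # both converge to (a+b)/2 = [3. 1.]
```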
So far, all of this is still abstract and not yet tied to any concrete problem. Fortunately, Professor Boyd has posted some worked examples on his website; they are hard to follow, and it took me some effort to finally understand part of them.
2 Some Special Notation
Some of the symbols were not obvious to me at first; they are summarized below.
Proximity Operator
We saw above that the update of $x$ is
{x}^{+} = \mathop{argmin}\limits_x f({x})+ \frac{\rho}{2}\Vert{Ax}+{Bz}^k - {c}+{u}^k\Vert^2_2
Letting $v = -Bz+c-u$,
{x}^{+} = \mathop{argmin}\limits_x f({x})+ \frac{\rho}{2}\Vert{Ax}-v\Vert^2_2
and further taking $A=I$:
{x}^{+} = \mathop{argmin}\limits_x f({x})+ \frac{\rho}{2}\Vert{x}-v\Vert^2_2\tag{*}
The right-hand side of this expression can be denoted $\textbf{prox}_{f,\rho}(v)$, which I tentatively read as: the proximity operator of the function $f$ with penalty parameter $\rho$. Some formulas below will therefore be abbreviated using form (*).
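As a small worked instance (my own illustration): for the scalar function $f(x)=\frac{1}{2}x^2$, setting the derivative of (*) to zero gives $x+\rho(x-v)=0$, i.e. $\textbf{prox}_{f,\rho}(v)=\rho v/(1+\rho)$, and a brute-force grid search over the objective agrees:

```python
import numpy as np

# prox_{f,rho}(v) = argmin_x f(x) + rho/2*(x - v)^2 with f(x) = x^2/2.
# Setting the derivative x + rho*(x - v) to zero gives x = rho*v/(1 + rho).
rho, v = 2.0, 3.0
x_closed = rho * v / (1 + rho)               # analytic prox value

xs = np.linspace(-10.0, 10.0, 200_001)       # brute-force grid search
obj = 0.5 * xs**2 + 0.5 * rho * (xs - v)**2
x_num = xs[obj.argmin()]
print(x_closed, x_num)                       # both ~= 2.0
```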
Projection onto a Set (Projection)
When the function $f$ is simple enough, the $x$-update in the proximity-operator form above can be evaluated analytically. One example: if $f$ is the indicator function of a nonempty closed convex set $\mathcal{C}$, then the $x$-update can also be written as:
{x}^{+} = \mathop{argmin}\limits_x f({x})+ \frac{\rho}{2}\Vert{x}-v\Vert^2_2=\Pi_\mathcal{C}(v)
where $\Pi_\mathcal{C}$ denotes the Euclidean projection onto $\mathcal{C}$.
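As a concrete instance (my own example, assuming $\mathcal{C}$ is the nonnegative orthant $\{x : x \ge 0\}$), the projection is simply elementwise clipping at zero:

```python
import numpy as np

# If f is the indicator of C = {x : x >= 0}, the x-update Pi_C(v)
# is the Euclidean projection onto C: elementwise clipping at zero.
v = np.array([1.5, -0.3, 0.0, -2.0])
x_plus = np.maximum(v, 0.0)   # Pi_C(v)
print(x_plus)                 # [1.5 0.  0.  0. ]
```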
Soft Thresholding
Thanks to the articles by earlier authors1,2, I now have a basic understanding of soft thresholding.
The soft-thresholding operator has the form:
S_\kappa(a) =\begin{cases} a-\kappa&a>\kappa\\ 0&|a|\leq\kappa\\ a+\kappa&a<-\kappa \end{cases}=(a-\kappa)_+-(-a-\kappa)_+
It arises when solving optimization problems of the form
\mathop{argmin}\limits_x\Vert x-B\Vert^2_2+\lambda\Vert x\Vert_1
and it characterizes the value that the $x$-update takes. Here $B = [b_1, b_2, \dots, b_n]$. Since $\Vert x\Vert_1$ is not differentiable, a case-by-case analysis gives:
S_{\lambda/2}(b) =\begin{cases} b-\frac{\lambda}{2}&b>\frac{\lambda}{2}\\ 0&|b|\leq\frac{\lambda}{2}\\ b+\frac{\lambda}{2}&b<-\frac{\lambda}{2} \end{cases}=(b-\frac{\lambda}{2})_+-(-b-\frac{\lambda}{2})_+
For example, Boyd's paper gives the update of $x_i$ as:
x_i^+=\mathop{argmin}\limits_{x_i}(\lambda|x_i|+(\rho/2)(x_i-v_i)^2).
which yields the corresponding update:
x_i^+:=S_{\lambda/\rho}(v_i)
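To convince myself of this claim, here is a quick numerical check (my own illustration): for several values of $v_i$, a grid search over the scalar objective $\lambda|x_i|+(\rho/2)(x_i-v_i)^2$ lands on $S_{\lambda/\rho}(v_i)$:

```python
import numpy as np

def soft_threshold(v, kappa):
    # S_kappa(v) = (v - kappa)_+ - (-v - kappa)_+
    return np.maximum(v - kappa, 0.0) - np.maximum(-v - kappa, 0.0)

lam, rho = 1.0, 2.0                      # threshold kappa = lam/rho = 0.5
xs = np.linspace(-5.0, 5.0, 1_000_001)   # dense grid for brute force
for v in [-3.0, -0.2, 0.0, 0.4, 2.5]:
    obj = lam * np.abs(xs) + 0.5 * rho * (xs - v) ** 2
    x_num = xs[obj.argmin()]
    assert abs(x_num - soft_threshold(v, lam / rho)) < 1e-4
print("grid search matches S_{lambda/rho}")
```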
3 Code Walkthrough
Below I analyze the code for the lasso problem. The code is taken from the website; here I simply retrace the programming flow and organize my thoughts.
3.1 Code Structure
The code is split into lasso.m, objective.m, shrinkage.m, and factor.m; at run time their respective roles are:
- lasso.m drives the overall ADMM execution flow, including data logging, iteration, output, and the termination test;
- objective.m evaluates the objective function of the optimization problem;
- shrinkage.m implements the soft-thresholding (shrinkage) operator used in the z-update;
- factor.m computes a factorization that depends on the shape of A.
- A worked example is then used to check that the code correctly implements the algorithm.
3.2 Code Breakdown
3.2.1 lasso.m
The objective of the lasso problem is:
\text{minimize} \quad \frac{1}{2}\Vert Ax-b\Vert^2_2+\lambda\Vert x\Vert_1
Rewritten in a form that the ADMM method can solve:
\begin{align*} \text{minimize}\quad &\frac{1}{2}\Vert Ax-b\Vert^2_2+\lambda\Vert z\Vert_1\\ \text{subject to}\quad &x-z = 0 \end{align*}
The iterations are:
\begin{align*} x^{k+1}&:=(A^TA+\rho I)^{-1}(A^Tb+\rho(z^k-u^k))\\ z^{k+1}&:=S_{\lambda/\rho}(x^{k+1}+u^k)\\ u^{k+1}&:=u^k+x^{k+1}-z^{k+1} \end{align*}
function [z, history] = lasso(A, b, lambda, rho, alpha)
% lasso  Solve lasso problem via ADMM
%
%   [z, history] = lasso(A, b, lambda, rho, alpha);
%
% Solves the following problem via ADMM:
%
%   minimize 1/2*|| Ax - b ||_2^2 + \lambda || x ||_1
%
% The solution is returned in the vector x.
%
% history is a structure that contains the objective value, the primal and
% dual residual norms, and the tolerances for the primal and dual residual
% norms at each iteration.
%
% rho is the augmented Lagrangian parameter.
%
% alpha is the over-relaxation parameter (typical values for alpha are
% between 1.0 and 1.8).
%
% More information can be found in the paper linked at
% http://www.stanford.edu/~boyd/papers/distr_opt_stat_learning_admm.html

t_start = tic;

%% Global constants and defaults
QUIET    = 0;
MAX_ITER = 1000;
ABSTOL   = 1e-4;
RELTOL   = 1e-2;

%% Data preprocessing
[m, n] = size(A);
Atb = A'*b;               % save a matrix-vector multiply

%% ADMM solver
x = zeros(n,1);
z = zeros(n,1);
u = zeros(n,1);

% cache the factorization
[L, U] = factor(A, rho);

if ~QUIET
    fprintf('%3s\t%10s\t%10s\t%10s\t%10s\t%10s\n', 'iter', ...
        'r norm', 'eps pri', 's norm', 'eps dual', 'objective');
end

for k = 1:MAX_ITER
    % x-update
    q = Atb + rho*(z - u);    % temporary value
    if ( m >= n )             % if skinny
        x = U \ (L \ q);
    else                      % if fat
        x = q/rho - (A'*(U \ ( L \ (A*q) )))/rho^2;
    end

    % z-update with relaxation
    zold = z;
    x_hat = alpha*x + (1 - alpha)*zold;
    z = shrinkage(x_hat + u, lambda/rho);

    % u-update
    u = u + (x_hat - z);

    % diagnostics, reporting, termination checks
    history.objval(k)   = objective(A, b, lambda, x, z);
    history.r_norm(k)   = norm(x - z);
    history.s_norm(k)   = norm(-rho*(z - zold));
    history.eps_pri(k)  = sqrt(n)*ABSTOL + RELTOL*max(norm(x), norm(-z));
    history.eps_dual(k) = sqrt(n)*ABSTOL + RELTOL*norm(rho*u);

    if ~QUIET
        fprintf('%3d\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.2f\n', k, ...
            history.r_norm(k), history.eps_pri(k), ...
            history.s_norm(k), history.eps_dual(k), history.objval(k));
    end

    if (history.r_norm(k) < history.eps_pri(k) && ...
        history.s_norm(k) < history.eps_dual(k))
        break;
    end
end

if ~QUIET
    toc(t_start);
end
end
3.2.2 objective.m
function p = objective(A, b, lambda, x, z)
    p = ( 1/2*sum((A*x - b).^2) + lambda*norm(z,1) );
end
3.2.3 shrinkage.m
This implements the expression:
z^{k+1}:=S_{\lambda/\rho}(x^{k+1}+u^k)=(x^{k+1}+u^k-\lambda/\rho)_+-(-x^{k+1}-u^k-\lambda/\rho)_+
function z = shrinkage(x, kappa)
    z = max( 0, x - kappa ) - max( 0, -x - kappa );
end
3.2.4 factor.m
$A$ is an $m\times n$ matrix. When $m<n$, the matrix $A^TA+\rho I$ appearing in the $x$-update is $n\times n$; but by the matrix inversion lemma, after factoring out $1/\rho$ the work reduces to a system involving $I+(1/\rho)AA^T$, which is only $m\times m$. Since matrix inversion costs $O(n^3)$, the latter approach is cheaper to solve when $m<n$. Combined with sparse-matrix techniques, this yields the code below.
function [L, U] = factor(A, rho)
    [m, n] = size(A);
    if ( m >= n )    % if skinny
        L = chol( A'*A + rho*speye(n), 'lower' );
    else             % if fat
        L = chol( speye(m) + 1/rho*(A*A'), 'lower' );
    end

    % force matlab to recognize the upper / lower triangular structure
    L = sparse(L);
    U = sparse(L');
end
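To check the fat-case ($m<n$) x-update of lasso.m against a direct solve (my own NumPy check, not part of the original code), the matrix inversion lemma says $(A^TA+\rho I)^{-1}q = q/\rho - A^T(I+AA^T/\rho)^{-1}(Aq)/\rho^2$, so only an $m\times m$ system needs solving:

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, rho = 5, 20, 1.5                    # fat case: m < n
A = rng.standard_normal((m, n))
q = rng.standard_normal(n)

# direct route: solve the n-by-n system (A^T A + rho I) x = q
x_direct = np.linalg.solve(A.T @ A + rho * np.eye(n), q)

# matrix-inversion-lemma route from lasso.m: only an m-by-m solve
M = np.eye(m) + (A @ A.T) / rho
x_lemma = q / rho - A.T @ np.linalg.solve(M, A @ q) / rho**2

print(np.allclose(x_direct, x_lemma))     # True
```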
3.2.5 Difficulties and Implementation
The main difficulty lies in manipulating and deriving the matrix expressions, where mistakes are easy to make, which is what makes reproducing the method hard. A worked example (Example) applying the functions above can be found at the linked site.
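To retrace the whole flow outside MATLAB, here is a compact Python re-implementation of the lasso iterations above (a sketch following the same update rules as lasso.m, but without the over-relaxation step or stopping test; the data and names are my own):

```python
import numpy as np

def lasso_admm(A, b, lam, rho=1.0, n_iter=500):
    """ADMM for min 1/2||Ax - b||^2 + lam*||x||_1 (no relaxation)."""
    m, n = A.shape
    x = z = u = np.zeros(n)
    AtA_rhoI = A.T @ A + rho * np.eye(n)
    Atb = A.T @ b
    for _ in range(n_iter):
        x = np.linalg.solve(AtA_rhoI, Atb + rho * (z - u))  # x-update
        v = x + u                                           # z-update:
        z = np.maximum(v - lam/rho, 0) - np.maximum(-v - lam/rho, 0)
        u = u + x - z                                       # u-update
    return z

rng = np.random.default_rng(1)
A = rng.standard_normal((30, 10))
x_true = np.zeros(10)
x_true[:3] = [2.0, -1.5, 1.0]
b = A @ x_true
z = lasso_admm(A, b, lam=0.1)
print(np.round(z, 2))   # sparse estimate close to x_true
```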