交替方向乘子法(ADMM)的数学基础

交替方向乘子法(ADMM)

网上的一些资料根本就没有把ADMM的来龙去脉说清楚,发现只是一个地方简单写了一下流程,别的地方就各种抄,共轭函数,对偶梯度上升什么的,都没讲清楚,给跪了。下面我来讲讲在机器学习中用得很多的ADMM方法到底是何方神圣。

共轭函数

给定函数 f : R n → R f: \mathbb{R}^{n} \rightarrow \mathbb{R} f:RnR,那么函数
f ∗ ( y ) = max ⁡ x ( y T x − f ( x ) ) f^{*}(y)=\max _{x}( y^{T} x-f(x)) f(y)=xmax(yTxf(x))
就叫做它的共轭函数。其实一个更直观的理解是:对一个固定的 y y y,将 y T x y^Tx yTx看成是一条斜率为 y y y的直线,它和 f ( x ) f(x) f(x)关于 x x x的距离的最大值,就是 f ∗ ( y ) f^*(y) f(y)百度百科有关于这个直观说法的一个解释,看看就明白了。

关于共轭函数有几点重要的说明:

  • 不管 f f f凸不凸,它的共轭函数总是一个凸函数。

  • 如果 f f f是闭凸的(闭指定义域是闭的),那么 f ∗ ∗ = f f^{**}=f f=f

  • 如果 f f f是严格凸的,那么
    ∇ f ∗ ( y ) = argmin ⁡ z ( f ( z ) − y T z ) \nabla f^{*}(y)=\underset{z}{\operatorname{argmin}} (f(z)-y^{T} z) f(y)=zargmin(f(z)yTz)

  • 共轭总是频频出现在对偶规划中,因为极小问题总是容易凑出一个共轭: − f ∗ ( y ) = min ⁡ x ( f ( x ) − y T x ) -f^{*}(y)=\min _{x} (f(x)-y^{T} x) f(y)=minx(f(x)yTx)

关于 f ∗ f^* f的凸性,这篇博客给了一个比较直观的图示说明,下面我从数学上,不太严格地做个简单证明。以一维的情况说明。

假设 f f f是一个凸函数,下面都不妨考虑函数的最值都不再边界处取到。 max ⁡ x ( y x − f ( x ) ) \max _{x} (yx-f(x)) maxx(yxf(x))的极值点在 f x ′ ( x ) = y f'_x(x)=y fx(x)=y处取到,定义 g : = ( f x ′ ) − 1 g:=(f'_x)^{-1} g:=(fx)1,那么 x = g ( y ) x=g(y) x=g(y)可能会是一堆点。则有
f ∗ ( y ) = max ⁡ x ( y x − f ( x ) ) = y g ( x ) − f ( g ( y ) ) f^*(y)=\max _{x}(yx-f(x))=yg(x)-f(g(y)) f(y)=xmax(yxf(x))=yg(x)f(g(y)) 进而
( f ∗ ( y ) ) y ′ = g ( y ) + y g y ′ ( y ) − f x ′ ( g ( y ) ) g y ′ ( y ) = g ( y ) (f^*(y))'_y = g(y)+yg'_y(y)-f'_x(g(y))g'_y(y) = g(y) (f(y))y=g(y)+ygy(y)fx(g(y))gy(y)=g(y) 那么
( f ∗ ( y ) ) y ′ ′ y = g y ′ ( y ) = ( ( f x ′ ) − 1 ) y ′ ≥ 0 (f^*(y))''_yy=g'_y(y)=((f'_x)^{-1})'_y \geq 0 (f(y))yy=gy(y)=((fx)1)y0 故而, f ∗ f^* f是凸的。

关于 f ∗ ∗ = f f^{**}=f f=f,也是容易证明的。我们假设 f f f是闭凸的。 max ⁡ y ( z y − f ∗ ( y ) ) \max _y(zy-f^*(y)) maxy(zyf(y))的值在 g ( y ) = z g(y)=z g(y)=z处取到,那么
f ∗ ∗ ( z ) = max ⁡ y ( z y − f ∗ ( y ) ) = f ( g ( g − 1 ( z ) ) ) = f ( z ) f^{**}(z) = \max_y(zy-f^{*}(y))=f(g(g^{-1}(z)))=f(z) f(z)=ymax(zyf(y))=f(g(g1(z)))=f(z)
z z z换成 x x x就是 f ( x ) f(x) f(x)

第三条非常重要,它说明了共轭函数的梯度,其实就是共轭函数取到极大值对应的 x x x值,它从 ( f ∗ ( y ) ) y ′ = g ( y ) (f*(y))'_y = g(y) (f(y))y=g(y)就可以看出来。

对偶梯度上升法

有了上知识的铺垫,我们就可以说清楚对偶上升方法了。以考虑等式约束问题为例(一般约束问题也是类似的流程),假设 f ( x ) f(x) f(x)是严格凸的,我们考虑问题:
min ⁡ x f ( x )  subject to  A x = b \min _{x} f(x) \text { subject to } A x=b xminf(x) subject to Ax=b 它的拉格朗日对偶问题是:
max ⁡ u min ⁡ x ( f ( x ) + u T ( A x − b ) ) \max _{u}\min _{x} (f(x)+u^T(Ax-b)) umaxxmin(f(x)+uT(Axb))
有理论表明,若原问题和对偶问题满足强对偶条件,即对偶函数关于 u u u的最大值等价于原优化问题关于 x x x的最小。那么原问题和对偶问题对于 x x x是同解的。也就是说只要找到使得对偶问题对应最大的 u u u,其对应的 x x x就是原优化问题的解,那么我们就解决了原始优化问题。

所以,下面我们来求解这个对偶问题。先把和 x x x无关的变量提出 min ⁡ x \min _x minx,再想办法凑出 f ∗ f^* f,因为我们要用到对偶的性质。
KaTeX parse error: No such environment: split at position 7: \begin{̲s̲p̲l̲i̲t̲}̲ \max _u \min…

那么对偶问题就成了 max ⁡ u − f ∗ ( − A T u ) − b T u \max _{u}-f^{*}\left(-A^{T} u\right)-b^{T} u umaxf(ATu)bTu
这里 f ∗ f^* f f f f的共轭,这里 max ⁡ u \max _u maxu后面不加括号,表示它管着下面的所有,下同,不再重述。定义 g ( u ) = − f ∗ ( − A T u ) − b T u g(u)=-f^{*}\left(-A^{T} u\right)-b^{T} u g(u)=f(ATu)bTu,我们希望能极大化 g ( u ) g(u) g(u),一个简单的想法是沿着 g ( u ) g(u) g(u)梯度上升的方向去走。注意到,
∂ g ( u ) = A ∂ f ∗ ( − A T u ) − b \partial g(u)=A \partial f^{*}\left(-A^{T} u\right)-b g(u)=Af(ATu)b
因此,利用共轭的性质,
∂ g ( u ) = A x − b  where  x ∈ argmin ⁡ z f ( z ) + u T A z \partial g(u)=A x-b \text { where } x \in \underset{z}{\operatorname{argmin}} f(z)+u^{T} A z g(u)=Axb where xzargminf(z)+uTAz
因为 f f f是严格凸的, f ∗ f^* f是可微的,那么,就有了所谓的对偶梯度上升方法。从一个对偶初值 u ( 0 ) u^{(0)} u(0)开始,重复以下过程:
x ( k ) = argmin ⁡ x f ( x ) + ( u ( k − 1 ) ) T A x u ( k ) = u ( k − 1 ) + t k ( A x ( k ) − b ) \begin{aligned} &x^{(k)}=\underset{x}{\operatorname{argmin}} f(x)+\left(u^{(k-1)}\right)^{T} A x\\ &u^{(k)}=u^{(k-1)}+t_{k}\left(A x^{(k)}-b\right) \end{aligned} x(k)=xargminf(x)+(u(k1))TAxu(k)=u(k1)+tk(Ax(k)b)
这里的步长 t k t_k tk使用标准的方式选取的。近端梯度和加速可以应用到这个过程中进行优化。

交替方向乘子法

交替方向乘子法(ADMM)是一种求解具有可分离的凸优化问题的重要方法,由于处理速度快,收敛性能好,ADMM算法在统计学习、机器学习等领域有着广泛应用。ADMM算法一般用于解决如下的凸优化问题:
min ⁡ x , y f ( x ) + g ( y )  subject to  A x + B y = c \min _{x, y} f(x)+g(y) \text { subject to } A x+B y=c x,yminf(x)+g(y) subject to Ax+By=c
其中的 f f f g g g都是凸函数。

它的增广拉格朗日函数如下:
L p ( x , y , λ ) = f ( x ) + g ( y ) + λ T ( A x + B y − c ) + ( ρ / 2 ) ∥ A x + B y − c ∥ 2 2 , ρ > 0 L_{p}(x, y, \lambda)=f(x)+g(y)+\lambda^{T}(A x+B y-c)+(\rho / 2)\|A x+B y-c\|_{2}^{2}, \rho>0 Lp(x,y,λ)=f(x)+g(y)+λT(Ax+Byc)+(ρ/2)Ax+Byc22,ρ>0

ADMM算法求解思想和推导同梯度上升法,最后重复迭代以下过程:
x k + 1 : = arg ⁡ min ⁡ x L p ( x , y , λ ) x k + 1 : = arg ⁡ min ⁡ y L p ( x , y , λ ) λ k + 1 : = λ k + ρ ( A x k + 1 + B y k + 1 − c ) \begin{aligned} x^{k+1} &:=\arg \min _x L_{p}(x, y, \lambda) \\ x^{k+1} &:=\arg \min _y L_{p}(x, y, \lambda) \\ \lambda^{k+1} &:=\lambda^{k}+\rho\left(A x^{k+1}+B y^{k+1}-c\right) \end{aligned} xk+1xk+1λk+1:=argxminLp(x,y,λ):=argyminLp(x,y,λ):=λk+ρ(Axk+1+Byk+1c) 上述迭代可以进行简化。

  • 第一步简化,通过公式 ∥ a + b ∥ 2 2 = ∥ a ∥ 2 2 + ∥ b ∥ 2 2 + 2 a T b \|a+b\|_{2}^{2}=\|a\|_{2}^{2}+\|b\|_{2}^{2}+2 a^{T} b a+b22=a22+b22+2aTb,替换掉拉格朗日函数中的线性项 λ T ( A x + B y − c ) \lambda^{T}(A x+B y-c) λT(Ax+Byc)和二次项 ρ / 2 ∥ A x + B y − c ∥ 2 2 \rho/2\|A x+B y-c\|_{2}^{2} ρ/2Ax+Byc22,可以得到
    λ T ( A x + B y − c ) + ρ / 2 ∥ A x + B y − c ∥ 2 2 = ρ / 2 ∥ A x + B y − c + λ / ρ ∥ 2 2 − ρ / 2 ∥ λ / ρ ∥ 2 2 \lambda^{T}(A x+B y-c)+\rho/2\|A x+B y-c\|_{2}^{2}=\rho / 2\|A x+B y-c+\lambda/\rho\|_{2}^{2}-\rho / 2\|\lambda / \rho\|_{2}^{2} λT(Ax+Byc)+ρ/2Ax+Byc22=ρ/2Ax+Byc+λ/ρ22ρ/2λ/ρ22
    那么ADMM的过程可以化简如下: x k + 1 : = arg ⁡ min ⁡ x ( f ( x ) + ρ / 2 ∥ A x + B y k − c + λ k / ρ ∥ 2 2 y k + 1 : = arg ⁡ min ⁡ y ( g ( y ) + ρ / 2 ∥ A x k + 1 + B y − c + λ k / ρ ∥ 2 2 λ k + 1 : = λ k + ρ ( A x k + 1 + B y k + 1 − c ) \begin{aligned} x^{k+1} &:={\arg \min _x}\left(f(x)+\rho / 2\left\|A x+B y^{k}-c+\lambda^{k} / \rho\right\|_{2}^{2}\right.\\ y^{k+1} &:={\arg \min _y}\left(g(y)+\rho / 2\left\|A x^{k+1}+B y-c+\lambda^{k} / \rho\right\|_{2}^{2}\right.\\ \lambda^{k+1} &:=\lambda^{k}+\rho\left(A x^{k+1}+B y^{k+1}-c\right) \end{aligned} xk+1yk+1λk+1:=argxmin(f(x)+ρ/2Ax+Bykc+λk/ρ22:=argymin(g(y)+ρ/2Axk+1+Byc+λk/ρ22:=λk+ρ(Axk+1+Byk+1c)

  • 第二步化简,零缩放对偶变量 u = λ / ρ u = \lambda/\rho u=λ/ρ,于是ADMM过程可化简为:
    x k + 1 : = arg ⁡ min ⁡ ( f ( x ) + ρ / 2 ∥ A x + B y k − c + u k ∥ 2 2 y k + 1 : = arg ⁡ min ⁡ ( g ( y ) + ρ / 2 ∥ A x k + 1 + B y − c + u k ∥ 2 2 u k + 1 : = u k + ( A x k + 1 + B y k + 1 − c ) \begin{aligned} x^{k+1} &:={\arg \min}\left(f(x)+\rho / 2\left\|A x+B y^{k}-c+u^{k}\right\|_{2}^{2}\right.\\ y^{k+1} &:={\arg \min}\left(g(y)+\rho / 2\left\|A x^{k+1}+B y-c+u^{k}\right\|_{2}^{2}\right.\\ u^{k+1} &:=u^{k}+\left(A x^{k+1}+B y^{k+1}-c\right) \end{aligned} xk+1yk+1uk+1:=argmin(f(x)+ρ/2Ax+Bykc+uk22:=argmin(g(y)+ρ/2Axk+1+Byc+uk22:=uk+(Axk+1+Byk+1c)

ADMM相当于把一个大的问题分成了两个子问题,缩小了问题的规模,分而治之。

©️2020 CSDN 皮肤主题: 代码科技 设计师:Amelia_0503 返回首页