lternating Direction Method of Multiplier(ADMM) Algorithm

Alternating Direction Method of Multipliers (ADMM) 是一种通过将凸优化问题分解为一系列的易解子问题进行求解的算法,目前它在很多领域得到了广泛的应用。 [2].

This is simplified version, specifically for the LASSO:

给定一个稀疏向量 x ∈ R n x\in R^n xRn和矩阵 A ∈ R m × n A\in R^{m\times n} ARm×n
y = A x + e y=Ax+e y=Ax+e
其中 e e e是加性高斯白噪声。为了恢复信号 x x x,我们求解如下最小化问题
x ^ = min ⁡ x ∣ ∣ y − A x ∣ ∣ 2 2 + λ ∣ ∣ x ∣ ∣ 1 \hat{x} = \min_x ||y-Ax||_2^2 + \lambda||x||_1 x^=xminyAx22+λx1


在求解过程中,迭代地计算如下两个式子,直到满足收敛条件。
x k + 1 = ( A T A + ρ I ) − 1 ( A T y + ρ ( z − u ) ) x^{k+1} = (A^TA + \rho I )^{-1}(A^Ty + \rho (z - u)) xk+1=(ATA+ρI)1(ATy+ρ(zu))
z k + 1 = s i g n ( x ^ ) m a x ( 0 , ∣ x ∣ − λ ρ ) z^{k+1} = \mathrm{sign}(\hat{x})\mathrm{max}\left(0, |x| - \frac{\lambda}{\rho}\right) zk+1=sign(x^)max(0,xρλ)

下面是ADMM算法的PYTHON实现方式。 (http://stanford.edu/~boyd/admm.html).

import numpy as np
import matplotlib.pyplot as plt
from math import sqrt, log

def Sthresh(x, gamma):
    return np.sign(x)*np.maximum(0, np.absolute(x)-gamma/2.0)

def ADMM(A, y):

    m, n = A.shape
    w, v = np.linalg.eig(A.T.dot(A))
    MAX_ITER = 10000

    # Function to caluculate min 1/2(y - Ax) + l||x||
    # via alternating direction methods
    xhat = np.zeros([n, 1])
    zhat = np.zeros([n, 1])
    u = np.zeros([n, 1])

    # Calculate regression co-efficient and stepsize
    lamb = sqrt(2*log(n, 10))
    rho = 1/(np.amax(np.absolute(w)))

    # Pre-compute to save some multiplications
    AtA = A.T.dot(A)
    Aty = A.T.dot(y)
    Q = AtA + rho*np.identity(n)
    Q = np.linalg.inv(Q)

    for i in np.arange(1, MAX_ITER + 1):

        # x minimisation step via posterier OLS
        xhat = Q.dot(Aty + rho*(zhat - u))

        # z minimisation via soft-thresholding
        zhat = Sthresh(xhat + u, lamb/rho)

        # mulitplier update
        u = u + xhat - zhat

    return zhat, rho, lamb

def test(m=50, n=200):
    """Test the ADMM method with randomly generated matrices and vectors"""
    A = np.random.randn(m, n)

    num_non_zeros = 10
    positions = np.random.randint(0, n, num_non_zeros)
    amplitudes = 100*np.random.randn(num_non_zeros, 1)
    x = np.zeros((n, 1))
    x[positions] = amplitudes

    y = A.dot(x) + np.random.randn(m, 1)

    xhat, rho, lamb = ADMM(A, y)

    plt.plot(x, label='Original')
    plt.plot(xhat, label = 'Estimate')

    plt.legend(loc = 'upper right')

    plt.show()


if __name__ == "__main__":
    test()

参考文献:
[1] https://codereview.stackexchange.com/questions/108263/alternating-direction-method-of-multipliers
[2] http://stanford.edu/~boyd/admm.html

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值