注:本内容参考:ADMM优化算法(附MATLAB代码)
文章目录
-
- ADMM原理
- ADMM求解-算法实例
- 算法Matlab代码及注释
- 小结
ADMM原理
ADMM求解-算法实例
下面给出一个二次规划凸优化问题,采用ADMM算法求解的示例。问题的优化模型如下:
1、首先,构造目标函数的增广拉格朗日函数(ALM):
2、基于ALM我们可以构造交替优化的求解形式:
算法Matlab代码及注释
1、主函数:main.m
%% 定义参数
% x0,y0都是可行解
param.x0 = 1;
param.y0 = 1;
param.lambda = 1;
param.maxIter = 30;
param.beta = 1.1; % a constant
param.rho = 0.5;
[Hx,Fx] = getHession_F('f1');
[Hy,Fy] = getHession_F('f2');
param.Hx = Hx;
param.Fx = Fx;
param.Hy = Hy;
param.Fy = Fy;
%% solve problem using admm algrithm
[x,y] = solve_admm(param);
%% disp minimum
disp(['[x,y]:' num2str(x) ',' num2str(y)]);
2、针对本问题(二次规划),构造交替求解对应表达式,为了用于solve_admm.m的二次规划求解器函数quadprog求解。即需要得到ALM的Hessian矩阵及变量的一次项系数
关于quadprog函数,请参见matlab官方指南:quadprog-二次规划
构造子函数getHession_F.m:
function [H,F] = getHession_F(fn)
% 目的:服务于solve_admm.m的二次规划求解器函数quadprog求解
% fn : function name
% H : hessian matrix
% F : 一次项系数
syms x y lambda rho;
if strcmp(fn,'f1') % 判断输入函数是 f1 的话
f = (x-1)^2 + lambda*(2*x + 3*y -5) + rho/2*(2*x + 3*y -5)^2;
H = hessian(f,x); % 计算函数的 Hessian 矩阵 (2阶导数)
F = (2*lambda + (rho*(12*y - 20))/2 - 2); % x 一次项的系数:其实是由collect(f,{'x'})的x一次项系数得到
fcol = collect(f,{'x'}); % 固定y,默认x为符号变量,输出关于x的表达式 (Collect(s,v)命令用于将符号矩阵S中所有同类项合并,并以v为符号变量输出;)
disp(fcol); % 输出
elseif strcmp(fn,'f2') % 判断输入函数是 f2 的话
f = (y-2)^2 + lambda*(2*x + 3*y -5) + rho/2*(2*x + 3*y -5)^2;
H = hessian(f,y);
F = (3*lambda + (rho*(12*x - 30))/2 - 4); % y 的系数
fcol = collect(f,{'y'}); % 固定x,默认y为符号变量
disp(fcol);
end
end
3、ADMM子函数,用于交替优化求解x、y,子函数为solve_admm.m:
function [x,y] = solve_admm(param)
x = param.x0;
y = param.y0;
lambda = param.lambda;
beta = param.beta;
rho = param.rho;
Hx = param.Hx;
Fx = param.Fx;
Hy = param.Hy;
Fy = param.Fy;
%%
xlb = 0; % the lower bound of x
xub = 3; % the upper bound of x
ylb = 1;
yub = 4;
maxIter = param.maxIter;
i = 1;
funval = zeros(maxIter-1,1);
iterNum = zeros(maxIter-1,1);
while 1
if i == maxIter
break;
end
% solve x
Hxx = eval(Hx); % eval:可以把字符串当作命令来执行
Fxx = eval(Fx);
x = quadprog(Hxx,Fxx,[],[],[],[],xlb,xub,[]); % Quadratic programming function
% solve y
Hyy = eval(Hy);
Fyy = eval(Fy);
y = quadprog(Hyy,Fyy,[],[],[],[],ylb,yub,[]);
% update lambda
lambda = lambda + rho*(2*x + 3*y -5); % ascend 更新下降
funval(i) = compute_fval(x,y);
iterNum(i) = i;
i = i + 1;
end
plot(iterNum,funval,'-r');
end
4、原问题的目标函数,compute_fval.m:
function fval = compute_fval(x,y)
fval = (x-1)^2 + (y-2)^2;
end
5、运行结果:
命令行输出:
[x,y]:0.53846,1.3077