LM法 java版

package com.reallife.arithmetic;

import android.util.Log;
import Jama.Matrix;

public class GaussNewton {
	// double[] xData = new double[]{36, 37, 38, 39, 40, 41, 42, 43, 44, 45};
	// double[] yData = new double[]{2475.8, 2854.63, 3141.37, 3333.83,
	// 3455.6, 3522.2, 3557.77, 3568.27, 3573.63, 3586.83};
	private final String TAG = "GaussNewton";
	private double[] xData;
	private double[] yData;
	
	private double def_residual = 0.01;//默认迭代收敛条件
	private double residual = 0;//当前拟合精度
	private double lamda = 0.01;//迭代步长
	private double v = 10.0;//步长倍数
//	private boolean isUpdate = false;//是否需要更新雅可比矩阵
	
	private double[][] unitMatrix;// 单位阵
	private double[] bMatrix;// 系数β矩阵
	private double[] init_b;// 系数β矩阵
	private int index = 0;//计算方程的编号
	private boolean reCalculate = false;
	private int count = 0;
	private int c_count = 0;
	private int m;
	private int n;
	private int iterations = 1;// 迭代次数
	
	/**
	 * @param data
	 * 
	 * */
	public GaussNewton(double[][] data) {
		super();

		xData = new double[data.length];
		yData = new double[data.length];

		for (int i = 0; i < data.length; i++) {
			xData[i] = data[i][0];
			yData[i] = data[i][1]; 
//			Log.e(TAG, "GaussNewton  "+xData[i]+"  "+yData[i]);
		}
	}
	
	/**
	 * @param xData
	 * @param yData
	 * 
	 * */
	public GaussNewton(double[] xData, double[] yData) {
		super();

		this.xData = xData;
		this.yData = yData;
	}
	
	/**
	 * @param row 求解的未知参数个数
	 * @param init_b 未知参数初值数组
	 * @param def_residual 0表示使用默认值
	 * */
	public void setInit(int row, double[] init_b, double def_residual, int iterations, int index) throws MyAlgorithmException{
		// β1,β2迭代初值
		// bMatrix[0][0] = 23.4502;
		// bMatrix[1][0] = 0.4553;
		// bMatrix[2][0] = 3621.22;
		if (row == init_b.length) {
			
			if (def_residual < 0) {
				throw new MyAlgorithmException("Def_residual must be a number greater than 0.",2);
			}else if (def_residual > 0) {
				this.def_residual = def_residual;
			}
			
			if (iterations != 0) {
				this.iterations = iterations;
			}
			
			unitMatrix = new double[row][row];
			for (int i = 0; i < row; i++) {
				for (int j = 0; j < row; j++) {
					if (i==j) {
						unitMatrix[i][j] = 1;
					}else {
						unitMatrix[i][j] = 0;
					}
				}
			}
			
			if (!reCalculate) {
				this.init_b = new double[row];
				count = 0;
				c_count = 0;
			}
			bMatrix = new double[row];
			m = xData.length;
			n = row;
			//1、选取初值点bMatrix、拟合精度def_residual、迭代次数iterations、初始步长lamda、步长倍数
			lamda = 0.01;
			
			for (int i = 0; i < row; i++) {
				bMatrix[i] = init_b[i];
				if (!reCalculate) {
					this.init_b[i] = init_b[i];
				}
			}
			scal = init_b[1]/init_b[0];
			
			this.index = index;
			
		}else {
			throw new MyAlgorithmException("Row must be the same as the length of init_b.",0);
		}
	}

	/**
	 * 迭代公式求解,即1中公式⑩
	 * */
	public double[] magic() throws MyAlgorithmException {
		
		if (bMatrix == null) {
			throw new MyAlgorithmException("The initial value of bMatrix must be given.",1);
		}else {
			double[][] J = new double[m][n];
			double[][] JT = new double[n][m];
			double[][] invertedPart = null;
			
			//2、计算拟合精度 ------求方差r(β)矩阵: ri = yi - f(xi, b)
			double[][] residuals = new double[m][1];//残差集合
			double newresidual = 0;//残差
			for (int i = 0; i < m; i++) {
				if (index == 1) {
					residuals[i][0] = yData[i] - (bMatrix[2]/(1+Math.exp(bMatrix[0]-bMatrix[1]*xData[i])));
					if (i==0) {
						newresidual = Math.abs(residuals[i][0]);
					}else {
						if (newresidual > Math.abs(residuals[i][0])) {
							newresidual = Math.abs(residuals[i][0]);
						}
					}
//					newresidual += residuals[i][0]*residuals[i][0];
				}else {
					residuals[i][0] = yData[i] - (bMatrix[2] - Math.exp(bMatrix[0] - bMatrix[1] * xData[i]));
					if (i==0) {
						newresidual = Math.abs(residuals[i][0]);
					}else {
						if (newresidual > Math.abs(residuals[i][0])) {
							newresidual = Math.abs(residuals[i][0]);
						}
					}
//					newresidual += residuals[i][0]*residuals[i][0];
				}
			}
			residual = newresidual;
//			System.out.println("Step2、计算拟合精度   " + newresidual);
			
			for (int k = 0; k < iterations; k++) {
//				System.out.println("-------------------------------------------");
//				System.out.println("---------------------"+k+"--------------------");
				
				//3、计算雅可比矩阵
//				if (isUpdate || k == 0) {
					for (int i = 0; i < m; i++) {
						for (int j = 0; j < n; j++) {
							if (index == 1) {
								J[i][j] = secondDerivative1(xData[i], bMatrix[0], bMatrix[1], bMatrix[2], j);
//							System.out.println("Step3、计算雅可比矩阵   J[" + i + "]["+j+"]  " + J[i][j]);
							}else {
								J[i][j] = secondDerivative2(xData[i], bMatrix[0], bMatrix[1], bMatrix[2], j);
//							System.out.println("Step3、计算雅可比矩阵   J[" + i + "]["+j+"]  " + J[i][j]);
							}
						}
					}
//				}
				
				
				//4、计算JTJ+lamda*I
				//求转置矩阵JT
				JT = MatrixMath.invert(J);
				// 矩阵JT与J相乘
				invertedPart = MatrixMath.mult(JT, J);
				double[][] unitPart = MatrixMath.mult(unitMatrix, lamda);
				double[][] totalPart = MatrixMath.plus(invertedPart, unitPart);
				
				//5、求解delta
				double[][] reversedPart;
				try {
					// 求矩阵invertedPart的逆矩阵:(JT*J+lamda*I)^-1 
					Matrix reversedMatrix = new Matrix(totalPart);
					reversedPart = reversedMatrix.inverse().getArray();
				} catch (Exception e) {
					// 矩阵invertedPart行列式的值:|JT*J+lamda*I|
					double result = MatrixMath.mathDeterminantCalculation(totalPart);
					// 求矩阵invertedPart的逆矩阵:(JT*J+lamda*I)^-1 
					reversedPart = MatrixMath.getInverseMatrix(totalPart, result);
//					e.printStackTrace();
				}

//				System.out.println("-----------reversedPart  "+k+"-------------");
//				for (int i = 0; i < reversedPart.length; i++) {
//					for (int l = 0; l < reversedPart[i].length; l++) {
//						System.out.println("reversedPart[" + i + "]["+l+"]  " + reversedPart[i][l]);
//					}
//				}
				
				// 求矩阵积reversedPart*JT*residuals: (JT*J+lamda*I)^-1*JT*r
				double[][] products = MatrixMath.mult(MatrixMath.mult(reversedPart, JT), residuals);
				double[] product = new double[products.length];
				double[] product_temporary = new double[products.length];
				double delta = 0;
				for (int i = 0; i < products.length; i++) {
					product[i] = products[i][0];
					product_temporary[i] = bMatrix[i]+product[i];
					if (i==0) {
						delta = Math.abs(product[i]);
					}else {
						if (delta > Math.abs(product[i])) {
							delta = Math.abs(product[i]);
						}
					}
//					for (int l = 0; l < products[i].length; l++) {
//						System.out.println("Step5、求解delta  product[" + i + "][" + l + "]  " + products[i][l]);
//					}
				}
				
				//6、(1)判断当前精度是否小于上一次的精度
				for (int i = 0; i < m; i++) {
					if (index == 1) {
						residuals[i][0] = yData[i] - (product_temporary[2]/(1+Math.exp(product_temporary[0]-product_temporary[1]*xData[i])));
						if (i==0) {
							newresidual = Math.abs(residuals[i][0]);
						}else {
							if (newresidual > Math.abs(residuals[i][0])) {
								newresidual = Math.abs(residuals[i][0]);
							}
						}
//						newresidual += residuals[i][0]*residuals[i][0];
					}else {
						residuals[i][0] = yData[i] - (product_temporary[2] - Math.exp(product_temporary[0] - product_temporary[1] * xData[i]));
						if (i==0) {
							newresidual = Math.abs(residuals[i][0]);
						}else {
							if (newresidual > Math.abs(residuals[i][0])) {
								newresidual = Math.abs(residuals[i][0]);
							}
						}
//						newresidual += residuals[i][0]*residuals[i][0];
					}
				}
//				System.out.println("Step6、判断当前精度是否小于上一次的精度  newresidual  " + newresidual+"  "+(newresidual <= residual));
//				if (newresidual <= residual) {
					//7、(1.1)是,则迭代Xn+1=Xn+delta
					// 迭代公式, 即公式⑩
					bMatrix = MatrixMath.plus(bMatrix, product);
					residual = newresidual;
//					for (int i = 0; i < bMatrix.length; i++) {
//						System.out.println("Step7、迭代公式, 即公式⑩  bMatrix[" + i + "]  " + bMatrix[i]);
//					}
//					isUpdate = true;
					
					//8、(2)判断是否delta满足拟合精度
					if (Math.abs(product[0]) < def_residual && Math.abs(product[1]) < def_residual && Math.abs(product[2]) < def_residual) {
						//9、(2.1)是,则结束迭代
//						System.out.println("Step9、(2.1)是,则结束迭代  "+k+"  lamda  "+lamda);
//						for (int i = 0; i < bMatrix.length; i++) {
//							System.out.println(bMatrix[i]);
//						}
						break;
					}else {
						//10、(2.2)否,lamda=lamda/v,n=n+1,重复4
						if (newresidual < residual) {
							lamda=lamda/v;
						}else {
							lamda=lamda*v;
						}
//						lamda=lamda/v;
//						System.out.println("Step10、lamda  " + lamda);
					}
//				}else {
//					//11、(1.2)否,lamda=lamda*v,n=n+1,重复4
//					
//					bMatrix = MatrixMath.plus(bMatrix, product);
//					// 迭代公式, 即公式⑩
//					if (Math.abs(product[0]) < def_residual && Math.abs(product[1]) < def_residual && Math.abs(product[2]) < def_residual) {
//						residual = newresidual;
//						for (int i = 0; i < bMatrix.length; i++) {
//							System.out.println("Step11、满足条件,退出循环  bMatrix[" + i + "]  " + bMatrix[i]);
//						}
//						break;
//					}else {
//						isUpdate = false;
//						lamda=lamda*v;
//					}
//					System.out.println("Step11、lamda  " + lamda);
//				}
				
				
//				if (k == 0) {
//					residual = newresidual;
//				}else {
//					System.out.println("-----------residual  "+residual+"  "+newresidual+"-------------");
//					if (newresidual > residual) {
//						isMore = true;
//						lamda = lamda*10;
//					}else {
//						isMore = false;
//						lamda = lamda/10;
//					}
//					residual = newresidual;
//				}
//				
//				System.out.println("-----------lamda  "+lamda+"-------------");
//				
//				//判断product是否满足条件,若满足条件则跳出迭代循环
//				boolean iteration_end = true;
//				for (int i = 0; i < product.length; i++) {
//					// || Double.isNaN(product[i][0])
//					if (Math.abs(product[i]) > def_residual) {
//						iteration_end = false;
//					}
//				}
//				if (iteration_end) {
//					break;
//				}
			}
			
			// 计算拟合优度
			
//			// 显示系数值
//			System.out.println("----------------------------------");
//			System.out.println("b1: " + bMatrix[0] + "\nb2: " + bMatrix[1] + "\nb3: " + bMatrix[2]+"  "+(Math.abs(bMatrix[2]) > Math.pow(10, 10)));
//			System.out.println("***********"+c_count+"**********"+count+"*************");
			if (index == 1 && (Double.isNaN(bMatrix[0]) || Double.isNaN(bMatrix[1]) || Double.isNaN(bMatrix[2])
					|| bMatrix[0] < 0
					|| bMatrix[1]/bMatrix[0] > 0.035 || bMatrix[0] < 10)) {
				
				if (count == 0) {
					if (bMatrix[1]/bMatrix[0] > 0.035) {
						scal = 0.015;
					}
				}else {
					if (scal >= 0.035) {
						scal = 0.035;
					}else {
						scal = 0.015+0.001*count;
					}
				}
				
				count ++;
				
				if (count < 22) {
					double r2 = calculateR2(bMatrix);
					if (r2 > best_r2) {
						best_r2 = r2;
						best_value[0] = bMatrix[0];
						best_value[1] = bMatrix[1];
						best_value[2] = bMatrix[2];
					}
					bMatrix[0] = init_b[0];
					bMatrix[1] = init_b[0]*scal;
					bMatrix[2] = init_b[2];
//					System.out.println(">>>  count < 22 init change  "+bMatrix[0]+"  "+bMatrix[1]+"  "+bMatrix[2]);
					//重新计算  初始化计数参数
					reCalculate = true;
					setInit(bMatrix.length, bMatrix, def_residual, iterations, index);
					magic();
				}else {
					bMatrix[0] = best_value[0];
					bMatrix[1] = best_value[1];
					bMatrix[2] = best_value[2];
				}
//				else {
//					c_count ++;
//					if (c_count <= 30) {
//						count = 0;
//						scal = 0.015;
//						bMatrix[0][0] = init_b[0];
//						bMatrix[1][0] = init_b[0]*scal;
//						bMatrix[2][0] = init_b[2]*Math.pow(1.4, c_count);
//						System.out.println(">>>  c_count < 5 init change  "+bMatrix[0][0]+"  "+bMatrix[1][0]+"  "+bMatrix[2][0]);
//						magic();
//					}
//				}
			}
		}
		return bMatrix;
	}
	private double[] best_value = new double[3];
	private double best_r2 = 0;
	private double scal = 0;
	
	private double calculateR2(double[] bMatrix){
		double R2 = 0;
		
		/**
		 * (1)计算残差平方和Q=∑(y-y*)^2和∑y^2,其中,y代表的是实测值,y*代表的是预测值;
		 * (2)拟合度指标RNew=1-(Q/∑y^2)^(1/2)
		 * */
		double residual = 0;//残差平方和
		double quadratic = 0;//平方和
		
		int length = yData.length;
		double value = 0;
		for (int i = 0; i < length; i++) {
			if (index == 1) {
				value = (float) (bMatrix[2]/(1+Math.exp(bMatrix[0]-bMatrix[1]*(i+1))));
			}else {
				value = (float) (bMatrix[2] - Math.exp(bMatrix[0] - bMatrix[1] * (i+1)));
			}
			residual += Math.pow(yData[i]-value,2.0);
			quadratic += Math.pow(yData[i],2.0);
		}
		R2 = 1-Math.sqrt(residual/quadratic);
//		Log.e(TAG, "R2  "+R2);
		
		return R2;
	}

	/** c-Math.exp(a-b*x)*/
	private static double secondDerivative2(double x, double b1, double b2, double b3, int index) {
		switch (index) {
			case 0:
				return -Math.exp(b1 - b2 * x);// 对系数b1求导
			case 1:
				return x * Math.exp(b1 - b2 * x);// 对系数b2求导
			case 2:
				return 1;// 对系数b3求导
		}
		return 0;
	}

	/** c(Math.pow((1+Math.exp(a-b*x)), 2)))*/
	private static double secondDerivative1(double x, double b1, double b2, double b3, int index) {
		switch (index) {
			case 0:
				return -b3*(Math.exp(b1-b2*x)/(Math.pow((1+Math.exp(b1-b2*x)), 2)));// 对系数b1求导
			case 1:
				return x*b3*(Math.exp(b1-b2*x)/(Math.pow((1+Math.exp(b1-b2*x)), 2)));// 对系数b2求导
			case 2:
				return 1/(1+Math.exp(b1-b2*x));// 对系数b3求导
		}
		return 0;
	}

}

算法的原理就不多啰嗦了,反正网上一大把

可以参考原论文:下载地址

可以参考K. Madsen等人的《Methods for non-linear least squares problems》:下载地址

 

看过LM算法的应该都知道,LM是在高斯牛顿法的基础上修改的,所以可以先去了解一下高斯牛顿法,其实楼主也是一知半解的,但是根据改编的方法可以解决楼主的问题,所以也就不深究了。如有大神,请勿喷。

 

因为是在项目中需要进行曲线拟合,需要非线性最小二乘法然后用LM法进行迭代求最优解。算法是根据原理一步一步写下来的,在实际使用过程中有问题,所以做了一些改编。本文只是作为学习笔记,如有需要使用的,可以参考下,不推荐照搬照抄。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值