package com.reallife.arithmetic;
import android.util.Log;
import Jama.Matrix;
/**
 * Fits a three-parameter curve to (x, y) data with a damped Gauss-Newton
 * (Levenberg-Marquardt style) iteration.
 *
 * <p>Two models are supported, selected by the {@code index} argument of {@link #setInit}:
 * <ul>
 *   <li>{@code index == 1}: f(x) = b[2] / (1 + exp(b[0] - b[1] * x))</li>
 *   <li>otherwise: f(x) = b[2] - exp(b[0] - b[1] * x)</li>
 * </ul>
 *
 * <p>Each step solves (J^T J + lamda I) delta = J^T r, where J is the Jacobian of f with respect
 * to b and r = y - f(x, b). The step is always accepted. lamda is divided by {@code v} when the
 * sum of squared residuals decreased and multiplied by {@code v} otherwise.
 *
 * <p>Instances hold mutable fitting state and perform no synchronization.
 */
public class GaussNewton {
    // Sample data set:
    // x = 36, 37, 38, 39, 40, 41, 42, 43, 44, 45
    // y = 2475.8, 2854.63, 3141.37, 3333.83, 3455.6, 3522.2, 3557.77, 3568.27, 3573.63, 3586.83

    /** Log tag (currently unused). */
    private final String TAG = "GaussNewton";

    /** Damping factor that setInit and every restart begin from. */
    private static final double INITIAL_LAMBDA = 0.01;
    /** Number of rejected logistic fits after which the best rejected candidate is returned. */
    private static final int MAX_REJECTED_FITS = 22;
    /** Logistic fits with b[1]/b[0] above this are rejected. */
    private static final double MAX_B1_OVER_B0 = 0.035;
    /** Logistic fits with b[0] below this are rejected. */
    private static final double MIN_B0 = 10;
    /** Base b[1]/b[0] ratio used to build restart guesses. */
    private static final double BASE_RATIO = 0.015;
    /** Per-restart increment of the restart ratio. */
    private static final double RATIO_STEP = 0.001;

    private double[] xData;
    private double[] yData;
    private double def_residual = 0.01;   // convergence threshold on every |delta_i|
    private double residual = 0;          // sum of squared residuals at the current estimate
    private double lamda = INITIAL_LAMBDA; // damping factor
    private final double v = 10.0;        // damping multiplier
    private double[][] unitMatrix;        // n x n identity
    private double[] bMatrix;             // current parameter estimate
    private double[] init_b;              // caller-supplied initial guess (copy)
    private int index = 0;                // model selector: 1 = logistic, otherwise exponential
    private int count = 0;                // rejected logistic fits in the current run
    private int m;                        // number of data points
    private int n;                        // number of parameters
    private int iterations = 1;           // maximum iterations per fit attempt
    private double[] best_value;          // best-scoring rejected candidate, or null
    private double best_r2 = Double.NEGATIVE_INFINITY;
    private double scal = 0;              // b[1]/b[0] ratio used for the next restart guess

    /**
     * Creates a fitter from (x, y) pairs.
     *
     * @param data rows of the form {x, y}; only the first two columns of each row are read
     */
    public GaussNewton(double[][] data) {
        super();
        xData = new double[data.length];
        yData = new double[data.length];
        for (int i = 0; i < data.length; i++) {
            xData[i] = data[i][0];
            yData[i] = data[i][1];
        }
    }

    /**
     * Creates a fitter from separate x and y arrays. The arrays are stored by reference.
     *
     * @param xData abscissae
     * @param yData observed values; must have the same length as {@code xData}
     * @throws IllegalArgumentException if the lengths differ
     */
    public GaussNewton(double[] xData, double[] yData) {
        super();
        if (xData.length != yData.length) {
            throw new IllegalArgumentException("xData and yData must have the same length: "
                    + xData.length + " != " + yData.length);
        }
        this.xData = xData;
        this.yData = yData;
    }

    /**
     * Prepares a new fit. This always starts a fresh run: restart counters and the remembered
     * best candidate from any earlier fit are cleared.
     *
     * @param row number of unknown parameters; both models read b[0..2], so at least 3 are needed
     * @param init_b initial parameter values (copied)
     * @param def_residual convergence threshold on every |delta_i|; 0 keeps the current value
     * @param iterations maximum iterations per fit attempt; 0 keeps the current value
     * @param index model selector: 1 = logistic, anything else = exponential
     * @throws MyAlgorithmException code 0 if {@code row != init_b.length}; code 2 if
     *     {@code def_residual < 0}
     */
    public void setInit(int row, double[] init_b, double def_residual, int iterations, int index) throws MyAlgorithmException {
        if (row != init_b.length) {
            throw new MyAlgorithmException("Row must be the same as the length of init_b.", 0);
        }
        if (def_residual < 0) {
            throw new MyAlgorithmException("Def_residual must be a number greater than 0.", 2);
        } else if (def_residual > 0) {
            this.def_residual = def_residual;
        }
        if (iterations != 0) {
            this.iterations = iterations;
        }

        unitMatrix = new double[row][row];
        for (int i = 0; i < row; i++) {
            unitMatrix[i][i] = 1;
        }

        // Fresh run: forget anything left over from a previous fit.
        this.init_b = init_b.clone();
        count = 0;
        best_r2 = Double.NEGATIVE_INFINITY;
        best_value = null;

        bMatrix = init_b.clone();
        m = xData.length;
        n = row;
        lamda = INITIAL_LAMBDA;
        scal = init_b[1] / init_b[0];
        this.index = index;
    }

    /**
     * Runs the fit from the current estimate and returns the fitted parameters.
     *
     * <p>For the logistic model ({@code index == 1}), a result is rejected when any of b[0..2]
     * is NaN, b[0] &lt; 10, or b[1]/b[0] &gt; 0.035. A rejected fit is restarted from
     * {init_b[0], init_b[0] * scal, init_b[2]}, with {@code scal} following the schedule in the
     * loop below. After 22 rejected fits, the rejected candidate with the highest
     * {@link #calculateR2} score is returned. If no candidate produced a comparable score
     * (all NaN), the last estimate is returned as is.
     *
     * @return the fitted parameter array (the internal array, not a copy)
     * @throws MyAlgorithmException code 1 if {@link #setInit} has not been called; anything
     *     thrown by MatrixMath is propagated
     */
    public double[] magic() throws MyAlgorithmException {
        if (bMatrix == null) {
            throw new MyAlgorithmException("The initial value of bMatrix must be given.", 1);
        }
        while (true) {
            iterate();
            if (index != 1 || !isImplausibleLogistic(bMatrix)) {
                return bMatrix;
            }

            // Choose the b[1]/b[0] ratio for the next restart guess.
            if (count == 0) {
                if (bMatrix[1] / bMatrix[0] > MAX_B1_OVER_B0) {
                    scal = BASE_RATIO;
                }
            } else if (scal >= MAX_B1_OVER_B0) {
                scal = MAX_B1_OVER_B0;
            } else {
                scal = BASE_RATIO + RATIO_STEP * count;
            }
            count++;

            // Remember the best rejected candidate, including the final one.
            double r2 = calculateR2(bMatrix);
            if (r2 > best_r2) {
                best_r2 = r2;
                best_value = bMatrix.clone();
            }

            if (count >= MAX_REJECTED_FITS) {
                if (best_value != null) {
                    bMatrix = best_value.clone();
                }
                return bMatrix;
            }
            restartFromScaledGuess();
        }
    }

    /** Resets the estimate to {init_b[0], init_b[0] * scal, init_b[2]} and resets the damping. */
    private void restartFromScaledGuess() {
        double[] start = bMatrix.clone(); // entries beyond index 2 keep their current values
        start[0] = init_b[0];
        start[1] = init_b[0] * scal;
        start[2] = init_b[2];
        bMatrix = start;
        lamda = INITIAL_LAMBDA;
        scal = start[1] / start[0];
    }

    /** Returns true if a logistic fit has NaN parameters or falls outside the accepted bounds. */
    private static boolean isImplausibleLogistic(double[] b) {
        return Double.isNaN(b[0]) || Double.isNaN(b[1]) || Double.isNaN(b[2])
                || b[1] / b[0] > MAX_B1_OVER_B0 || b[0] < MIN_B0;
    }

    /** Performs up to {@code iterations} damped steps starting from {@code bMatrix}. */
    private void iterate() throws MyAlgorithmException {
        double[][] jacobian = new double[m][n];
        double[][] residuals = new double[m][1];
        residual = computeResiduals(bMatrix, residuals);

        for (int k = 0; k < iterations; k++) {
            fillJacobian(jacobian, bMatrix);
            // MatrixMath.invert is used as the transpose J^T. The products below (J^T * J, and
            // (n x n) * J^T * (m x 1) residuals) only conform if it returns an n x m matrix.
            double[][] jt = MatrixMath.invert(jacobian);
            double[][] damped = MatrixMath.plus(MatrixMath.mult(jt, jacobian),
                    MatrixMath.mult(unitMatrix, lamda));
            double[][] step = MatrixMath.mult(MatrixMath.mult(inverse(damped), jt), residuals);

            double[] candidate = new double[step.length];
            boolean converged = true;
            for (int i = 0; i < step.length; i++) {
                candidate[i] = bMatrix[i] + step[i][0];
                // Written as !(x < eps) so that a NaN step never counts as converged.
                if (!(Math.abs(step[i][0]) < def_residual)) {
                    converged = false;
                }
            }

            double previous = residual;
            // Also leaves `residuals` evaluated at the candidate, ready for the next iteration.
            residual = computeResiduals(candidate, residuals);
            bMatrix = candidate; // the step is always accepted
            if (converged) {
                break;
            }
            // Compare with the value from before this step, then relax or tighten the damping.
            lamda = residual < previous ? lamda / v : lamda * v;
        }
    }

    /** Inverts a square matrix with Jama, falling back to MatrixMath if Jama rejects it. */
    private static double[][] inverse(double[][] a) throws MyAlgorithmException {
        try {
            return new Matrix(a).inverse().getArray();
        } catch (RuntimeException singular) {
            // Jama signals singular / rank-deficient matrices with a RuntimeException.
            double determinant = MatrixMath.mathDeterminantCalculation(a);
            return MatrixMath.getInverseMatrix(a, determinant);
        }
    }

    /** Writes y_i - f(x_i, b) into out[i][0] and returns the sum of squared residuals. */
    private double computeResiduals(double[] b, double[][] out) {
        double sum = 0;
        for (int i = 0; i < m; i++) {
            double r = yData[i] - model(xData[i], b);
            out[i][0] = r;
            sum += r * r;
        }
        return sum;
    }

    /** Evaluates the selected model at x. */
    private double model(double x, double[] b) {
        double e = Math.exp(b[0] - b[1] * x);
        return index == 1 ? b[2] / (1 + e) : b[2] - e;
    }

    /** Fills J[i][j] = d f(x_i, b) / d b[j] for every data point and parameter. */
    private void fillJacobian(double[][] jacobian, double[] b) {
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                jacobian[i][j] = partialDerivative(xData[i], b, j);
            }
        }
    }

    /**
     * First partial derivative of the selected model with respect to b[j]; 0 for j &gt; 2.
     */
    private double partialDerivative(double x, double[] b, int j) {
        double e = Math.exp(b[0] - b[1] * x);
        if (index == 1) {
            // f = b2 / (1 + e)
            double denominator = (1 + e) * (1 + e);
            switch (j) {
                case 0: return -b[2] * (e / denominator);
                case 1: return x * b[2] * (e / denominator);
                case 2: return 1 / (1 + e);
                default: return 0;
            }
        }
        // f = b2 - e
        switch (j) {
            case 0: return -e;
            case 1: return x * e;
            case 2: return 1;
            default: return 0;
        }
    }

    /**
     * Fit score R = 1 - sqrt(Q / sum(y^2)), where Q = sum((y - f(x, b))^2).
     * This is not the conventional coefficient of determination. It can be negative, and it is
     * NaN if b contains NaN.
     */
    private double calculateR2(double[] b) {
        double squaredError = 0;
        double squaredTotal = 0;
        for (int i = 0; i < yData.length; i++) {
            double diff = yData[i] - model(xData[i], b);
            squaredError += diff * diff;
            squaredTotal += yData[i] * yData[i];
        }
        return 1 - Math.sqrt(squaredError / squaredTotal);
    }
}
算法原理这里不再赘述,网上相关资料很多。
可参考原论文:下载地址(原文链接缺失)
也可参考K. Madsen等人的《Methods for Non-Linear Least Squares Problems》:下载地址(原文链接缺失)
了解LM算法的读者应该知道,LM算法是在高斯-牛顿法的基础上改进而来的:它在正规方程中加入阻尼项,即求解 (JᵀJ + λI)δ = Jᵀr。因此建议先了解高斯-牛顿法。笔者对其理解也不够深入,但改编后的方法能够解决笔者遇到的问题,因此没有进一步深究。如有错误,欢迎指正。
由于项目中需要进行曲线拟合,因此采用非线性最小二乘法,并用LM法迭代求解最优参数。算法是按照原理逐步实现的,实际使用中遇到了一些问题,因此做了若干改编。本文仅作为学习笔记,如需使用可作参考,不建议照搬照抄。