高斯牛顿迭代法

高斯牛顿迭代法

标签: 机器学习高斯牛顿迭代法
  618人阅读  评论(0)  收藏  举报
  分类:

转载:http://www.tqcto.com/article/code/302336.html#


本文将详解最小二乘法的非线性拟合高斯牛顿迭代法

1.原理

高斯—牛顿迭代法的基本思想是使用泰勒级数展开式去近似地代替非线性回归模型,然后通过多次迭代,多次修正回归系数,使回归系数不断逼近非线性回归模型的最佳回归系数,最后使原模型的残差平方和达到最小。

①已知m个点:

这里写图片描述

②函数原型:

这里写图片描述

其中:(m>=n)

这里写图片描述

③目的是找到最优解β,使得残差平方和最小:

这里写图片描述

残差:

这里写图片描述

④要求最小值,即S的对β偏导数等于0:

这里写图片描述

⑤在非线性系统中,这里写图片描述是变量和参数的函数,没有close解。因此,我们给定一个初始值,用迭代法逼近解:

这里写图片描述

其中k是迭代次数,这里写图片描述是迭代矢量。

⑥而每次迭代函数是线性的,在这里写图片描述处用泰勒级数展开:

这里写图片描述

其中:J是已知的矩阵,为了方便迭代,令这里写图片描述

⑦此时残差表示为:

这里写图片描述

这里写图片描述

⑧带入公式④有:

这里写图片描述

化解得:

这里写图片描述

⑨写成矩阵形式:

这里写图片描述

⑩所以最终迭代公式为:

这里写图片描述

其中,Jf是函数f=(x,β)对β的雅可比矩阵。

2.代码

Java代码实现,解维基百科的例子:

https://en.wikipedia.org/wiki/Gauss%E2%80%93Newton_algorithm

①已知数据:

这里写图片描述

②函数模型:

这里写图片描述

③残差公式:

这里写图片描述

④对β求偏导数:

这里写图片描述

⑤代码如下:

public class GaussNewton {
    double[] xData = new double[]{0.038, 0.194, 0.425, 0.626, 1.253, 2.500, 3.740};
    double[] yData = new double[]{0.050, 0.127, 0.094, 0.2122, 0.2729, 0.2665, 0.3317};

    double[][] bMatrix = new double[2][1];//系数β矩阵
    int m = xData.length;
    int n = bMatrix.length;
    int iterations = 7;//迭代次数

    //迭代公式求解,即1中公式⑩
    private void magic(){
        //β1,β2迭代初值
        bMatrix[0][0] = 0.9;
        bMatrix[1][0] = 0.2;

        //求J矩阵
        for(int k = 0; k < iterations; k++) { 
            double[][] J = new double[m][n];
            double[][] JT = new double[n][m];
            for(int i = 0; i < m; i++){
                for(int j = 0; j < n; j++) {
                    J[i][j] = secondDerivative(xData[i], bMatrix[0][0], bMatrix[1][0], j);
                }
            }

            JT = MatrixMath.invert(J);//求转置矩阵JT
            double[][] invertedPart = MatrixMath.mult(JT, J);//矩阵JT与J相乘

            //矩阵invertedPart行列式的值:|JT*J|
            double result = MatrixMath.mathDeterminantCalculation(invertedPart);

            //求矩阵invertedPart的逆矩阵:(JT*J)^-1
            double[][] reversedPart = MatrixMath.getInverseMatrix(invertedPart, result);

            //求方差r(β)矩阵: ri = yi - f(xi, b)
            double[][] residuals = new double[m][1];
            for(int i = 0; i < m; i++) {
                residuals[i][0] = yData[i] - (bMatrix[0][0] * xData[i]) / (bMatrix[1][0] + xData[i]);
            }

            //求矩阵积reversedPart*JT*residuals: (JT*J)^-1*JT*r
            double[][] product = MatrixMath.mult(MatrixMath.mult(reversedPart, JT), residuals);

            //迭代公式, 即公式⑩
            double[][] newB = MatrixMath.plus(bMatrix, product);
            bMatrix = newB;        
        }        
        //显示系数值
        System.out.println("b1: " + bMatrix[0][0] + "\nb2: " + bMatrix[1][0]);        
    }

    //2中公式④
    private static double secondDerivative(double x, double b1, double b2, int index){
        switch(index) {
            case 0: return x / (b2 + x);//对系数bi求导
            case 1: return -1 * (b1 * x) / Math.pow((b2+x), 2);//对系数b2求导
        }
        return 0;
    }

    public static void main(String[] args) {
        GaussNewton app = new GaussNewton();
        app.magic();        
    }
}
  
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65

运行,输出得到:

b1: 0.3618366954234483 
b2: 0.5562654497238557

这里写图片描述

⑥其中用到的矩阵运算代码如下:

public class MatrixMath {

    /**
     * 矩阵基本运算:加、减、乘、除、转置
     */
        public final static int OPERATION_ADD = 1;  
        public final static int OPERATION_SUB = 2;  
        public final static int OPERATION_MUL = 4;        

        /**
         * 矩阵加法
         * @param a 加数
         * @param b 被加数
         * @return 和
         */
        public static double[][] plus(double[][] a, double[][] b){
            if(legalOperation(a, b, OPERATION_ADD)) { 
                for(int i=0; i<a.length; i++) {         
                    for(int j=0; j<a[0].length; j++) {  
                        a[i][j] = a[i][j] + b[i][j];  
                    }
                }   
            }
            return a;
        }

        /**
         * 矩阵减法
         * @param a 减数
         * @param b 被减数
         * @return 差
         */
        public static double[][] substract(double[][] a, double[][] b){
            if(legalOperation(a, b, OPERATION_SUB)) { 
                for(int i=0; i<a.length; i++) {  
                    for(int j=0; j<a[0].length; j++) {  
                        a[i][j] = a[i][j] - b[i][j];  
                    }
                }       
            }
            return a;
        }       

        /**
         * 判断矩阵行列是否符合运算
         * @param a 进行运算的矩阵
         * @param b 进行运算的矩阵
         * @param type 运算类型
         * @return 符合/不符合运算
         */
        private static boolean legalOperation(double[][] a, double[][] b, int type) {  
            boolean legal = true;  
            if(type == OPERATION_ADD || type == OPERATION_SUB)  
            {  
                if(a.length != b.length || a[0].length != b[0].length) {  
                    legal = false;  
                }  
            }   
            else if(type == OPERATION_MUL)  
            {  
                if(a[0].length != b.length) {  
                    legal = false;  
                }  
            }  
            return legal;  
        } 

        /**
         * 矩阵乘法
         * @param a 乘数
         * @param b 被乘数
         * @return 积
         */
        public static double[][] mult(double[][] a, double[][] b){
            if(legalOperation(a, b, OPERATION_MUL)) {
                double[][] result = new double[a.length][b[0].length];
                for(int i=0; i< a.length; i++) {  
                    for(int j=0; j< b[0].length; j++) {  
                        result[i][j] = calculateSingleResult(a, b, i, j);  
                    }
                }
                return result;
            }
            else
            {
                return null;
            }       
        }

        /**
         * 矩阵乘法
         * @param a 乘数
         * @param b 被乘数
         * @return 积
         */
        public static double[][] mult(double[][] a, int b) {  
            for(int i=0; i<a.length; i++) {  
                for(int j=0; j<a[0].length; j++) {  
                    a[i][j] = a[i][j] * b;  
                }  
            }  
            return a;  
        }

        /**
         * 乘数矩阵的行元素与被乘数矩阵列元素积的和
         * @param a 乘数矩阵
         * @param b 被乘数矩阵
         * @param row 行
         * @param column 列
         * @return 值
         */
        private static double calculateSingleResult(double[][] a, double[][] b, int row, int column) {  
            double result = 0.0;  
            for(int k = 0; k< a[0].length; k++) {  
                result += a[row][k] * b[k][column];  
            }  
            return result;  
        }  

        /**
         * 矩阵的转置
         * @param a 要转置的矩阵
         * @return 转置后的矩阵
         */
        public static double[][] invert(double[][] a){
            double[][] result = new double[a[0].length][a.length];
            for(int i=0;i<a.length;i++){
                for(int j=0;j<a[0].length;j++){  
                      result[j][i]=a[i][j];  
                  }  
              }  
            return result;
        }   

    /** 
     * 求可逆矩阵(使用代数余子式的形式) 
     */   
        /** 
         * 求传入的矩阵的逆矩阵 
         * @param value 需要转换的矩阵 
         * @return 逆矩阵 
         */  
        public static double[][] getInverseMatrix(double[][] value,double result){  
            double[][] transferMatrix = new double[value.length][value[0].length];  
            //计算代数余子式,并赋值给|A|  
            for (int i = 0; i < value.length; i++) {  
                for (int j = 0; j < value[i].length; j++) {  
                    transferMatrix[j][i] =  mathDeterminantCalculation(getNewMatrix(i, j, value));  
                    if ((i+j)%2!=0) {  
                        transferMatrix[j][i] = -transferMatrix[j][i];  
                    }  
                    transferMatrix[j][i] /= result;   
                }  
            }  
            return transferMatrix;  
        }  

        /*** 
         * 求行列式的值
         * @param value 需要算的行列式 
         * @return 计算的结果 
         */  
        public static double mathDeterminantCalculation(double[][] value){  
            if (value.length == 1) {  
                //当行列式为1阶的时候就直接返回本身  
                return value[0][0];  
            }

            if (value.length == 2) {  
                //如果行列式为二阶的时候直接进行计算  
                return value[0][0]*value[1][1]-value[0][1]*value[1][0];  
            }  

            //当行列式的阶数大于2时  
            double result = 1;  
            for (int i = 0; i < value.length; i++) {         
                //检查数组对角线位置的数值是否是0,如果是零则对该数组进行调换,查找到一行不为0的进行调换  
                if (value[i][i] == 0) {  
                    value = changeDeterminantNoZero(value,i,i);  
                    result*=-1;  
                }  

                for (int j = 0; j <i; j++) {  
                    //让开始处理的行的首位为0处理为三角形式  
                    //如果要处理的列为0则和自己调换一下位置,这样就省去了计算  
                    if (value[i][j] == 0) {  
                        continue;  
                    }  
                    //如果要是要处理的行是0则和上面的一行进行调换  
                    if (value[j][j]==0) {  
                        double[] temp = value[i];  
                        value[i] = value[i-1];  
                        value[i-1] = temp;  
                        result*=-1;  
                        continue;  
                    }  
                    double  ratio = -(value[i][j]/value[j][j]);  
                    value[i] = addValue(value[i],value[j],ratio);  
                }  
            }           
            return mathValue(value,result);
        }  

        /** 
         * 计算行列式的结果 
         * @param value 
         * @return 
         */  
        private static double mathValue(double[][] value,double result){  
            for (int i = 0; i < value.length; i++) {  
                //如果对角线上有一个值为0则全部为0,直接返回结果  
                if (value[i][i]==0) {  
                    return 0;  
                }  
                result *= value[i][i];  
            }  
            return result;  
        }  

        /*** 
         * 将i行之前的每一行乘以一个系数,使得从i行的第i列之前的数字置换为0 
         * @param currentRow 当前要处理的行 
         * @param frontRow i行之前的遍历的行 
         * @param ratio 要乘以的系数 
         * @return 将i行i列之前数字置换为0后的新的行 
         */  
        private static double[] addValue(double[] currentRow,double[] frontRow, double ratio){  
            for (int i = 0; i < currentRow.length; i++) {  
                currentRow[i] += frontRow[i]*ratio;  
            }  
            return currentRow;  
        }  

        /** 
         * 指定列的位置是否为0,查找第一个不为0的位置的行进行位置调换,如果没有则返回原来的值 
         * @param determinant 需要处理的行列式 
         * @param line 要调换的行 
         * @param row 要判断的列 
         */  
        private static double[][] changeDeterminantNoZero(double[][] determinant,int line,int column){  
            for (int i = line; i < determinant.length; i++) {  
                //进行行调换  
                if (determinant[i][column] != 0) {  
                    double[] temp = determinant[line];  
                    determinant[line] = determinant[i];  
                    determinant[i] = temp;  
                    return determinant;  
                }  
            }  
            return determinant;  
        }         

        /** 
         * 转换为代数余子式 
         * @param row 行 
         * @param line 列 
         * @param matrix 要转换的矩阵 
         * @return 转换的代数余子式 
         */  
        private static double[][] getNewMatrix(int row,int line,double[][] matrix){  
            double[][] newMatrix = new double[matrix.length-1][matrix[0].length-1];  
            int rowNum = 0,lineNum = 0;  
            for (int i = 0; i < matrix.length; i++) {  
                if (i == row){  
                    continue;  
                }  
                for (int j = 0; j < matrix[i].length; j++) {  
                    if (j == line) {  
                        continue;  
                    }  
                    newMatrix[rowNum][lineNum++%(matrix[0].length-1)] = matrix[i][j];  
                }  
                rowNum++;  
            }  
            return newMatrix;  
        }  

        public static void main(String[] args) {  
            //double[][] test = {{0,0,0,1,2},{0,0,0,2,3},{1,1,0,0,0},{0,1,1,0,0},{0,0,1,0,0}};  
            double[][] test = {
                    {3.8067488033632655, -2.894113667134647},  
                    {-2.894113667134647, 3.6978894069779504}
                };
            double result;  
            try {  
                double[][] temp = new double[test.length][test[0].length];  
                for (int i = 0; i < test.length; i++) {  
                    for (int j = 0; j < test[i].length; j++) {  
                        temp[i][j] = test[i][j];  
                    }  
                }  
                //先计算矩阵的行列式的值是否等于0,如果不等于0则该矩阵是可逆的  
                result = mathDeterminantCalculation(temp);  
                if (result == 0) {  
                    System.out.println("矩阵不可逆");  
                }else {  
                    System.out.println("矩阵可逆");  
                    //求出逆矩阵  
                    double[][] result11 = getInverseMatrix(test,result);  
                    //打印逆矩阵  
                    for (int i = 0; i < result11.length; i++) {  
                        for (int j = 0; j < result11[i].length; j++) {  
                            System.out.print(result11[i][j]+"   ");                       
                        }  
                        System.out.println();  
                    }  
                }  
            } catch (Exception e) {  
                e.printStackTrace();  
                System.out.println("不是正确的行列式!!");  
            }  
        }   
}
  
  
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值