最小二乘法多项式拟合的Java实现--转

原文地址:http://blog.csdn.net/funnyrand/article/details/46742561

背景

由于项目中需要根据磁盘的历史使用情况预测未来一段时间的使用情况,决定采用最小二乘法做多项式拟合,这里简单描述下:

 

假设给定的数据点和其对应的函数值为 (x1, y1), (x2, y2), ... (xm, ym),需要做的就是得到一个多项式函数

f(x) = a0 * x + a1 * pow(x, 2) + .. + an * pow(x, n),使其对所有给定x所计算出的f(x)与实际对应的y值的差的平方和最小,

也就是计算多项式的各项系数 a0, a1, ... an.

 

根据最小二乘法的原理,该问题可转换为求以下线性方程组的解:Ga = B

所以从编程的角度来说需要做两件事情,1,确定线性方程组的各个系数,2,解线性方程组

确定系数比较简单,对给定的 (x1, y1), (x2, y2), ... (xm, ym) 做相应的计算即可,相关代码:

private void compute() {

...

}

解线性方程组稍微复杂,这里用到了高斯消元法,基本思想是通过递归做矩阵转换,逐渐减少求解的多项式系数的个数,相关代码:

private double[] calcLinearEquation(double[][] a, double[] b) {

...

}

 

Java代码

[java]  view plain  copy
 
  1. package com.my.study.algorithm.leastSquareMethod;  
  2.   
  3. /** 
  4.  * Least square method class. 
  5.  */  
  6. public class LeastSquareMethod {  
  7.   
  8.     private double[] x;  
  9.     private double[] y;  
  10.     private double[] weight;  
  11.     private int n;  
  12.     private double[] coefficient;  
  13.   
  14.     /** 
  15.      * Constructor method. 
  16.      *  
  17.      * @param x 
  18.      *            Array of x 
  19.      * @param y 
  20.      *            Array of y 
  21.      * @param n 
  22.      *            The order of polynomial 
  23.      */  
  24.     public LeastSquareMethod(double[] x, double[] y, int n) {  
  25.         if (x == null || y == null || x.length < 2 || x.length != y.length  
  26.                 || n < 2) {  
  27.             throw new IllegalArgumentException(  
  28.                     "IllegalArgumentException occurred.");  
  29.         }  
  30.         this.x = x;  
  31.         this.y = y;  
  32.         this.n = n;  
  33.         weight = new double[x.length];  
  34.         for (int i = 0; i < x.length; i++) {  
  35.             weight[i] = 1;  
  36.         }  
  37.         compute();  
  38.     }  
  39.   
  40.     /** 
  41.      * Constructor method. 
  42.      *  
  43.      * @param x 
  44.      *            Array of x 
  45.      * @param y 
  46.      *            Array of y 
  47.      * @param weight 
  48.      *            Array of weight 
  49.      * @param n 
  50.      *            The order of polynomial 
  51.      */  
  52.     public LeastSquareMethod(double[] x, double[] y, double[] weight, int n) {  
  53.         if (x == null || y == null || weight == null || x.length < 2  
  54.                 || x.length != y.length || x.length != weight.length || n < 2) {  
  55.             throw new IllegalArgumentException(  
  56.                     "IllegalArgumentException occurred.");  
  57.         }  
  58.         this.x = x;  
  59.         this.y = y;  
  60.         this.n = n;  
  61.         this.weight = weight;  
  62.         compute();  
  63.     }  
  64.   
  65.     /** 
  66.      * Get coefficient of polynomial. 
  67.      *  
  68.      * @return coefficient of polynomial 
  69.      */  
  70.     public double[] getCoefficient() {  
  71.         return coefficient;  
  72.     }  
  73.   
  74.     /** 
  75.      * Used to calculate value by given x. 
  76.      *  
  77.      * @param x 
  78.      *            x 
  79.      * @return y 
  80.      */  
  81.     public double fit(double x) {  
  82.         if (coefficient == null) {  
  83.             return 0;  
  84.         }  
  85.         double sum = 0;  
  86.         for (int i = 0; i < coefficient.length; i++) {  
  87.             sum += Math.pow(x, i) * coefficient[i];  
  88.         }  
  89.         return sum;  
  90.     }  
  91.   
  92.     /** 
  93.      * Use Newton's method to solve equation. 
  94.      *  
  95.      * @param y 
  96.      *            y 
  97.      * @return x 
  98.      */  
  99.     public double solve(double y) {  
  100.         return solve(y, 1.0d);  
  101.     }  
  102.   
  103.     /** 
  104.      * Use Newton's method to solve equation. 
  105.      *  
  106.      * @param y 
  107.      *            y 
  108.      * @param startX 
  109.      *            The start point of x 
  110.      * @return x 
  111.      */  
  112.     public double solve(double y, double startX) {  
  113.         final double EPS = 0.0000001d;  
  114.         if (coefficient == null) {  
  115.             return 0;  
  116.         }  
  117.         double x1 = 0.0d;  
  118.         double x2 = startX;  
  119.         do {  
  120.             x1 = x2;  
  121.             x2 = x1 - (fit(x1) - y) / calcReciprocal(x1);  
  122.         } while (Math.abs((x1 - x2)) > EPS);  
  123.         return x2;  
  124.     }  
  125.   
  126.     /* 
  127.      * Calculate the reciprocal of x. 
  128.      *  
  129.      * @param x x 
  130.      *  
  131.      * @return the reciprocal of x 
  132.      */  
  133.     private double calcReciprocal(double x) {  
  134.         if (coefficient == null) {  
  135.             return 0;  
  136.         }  
  137.         double sum = 0;  
  138.         for (int i = 1; i < coefficient.length; i++) {  
  139.             sum += i * Math.pow(x, i - 1) * coefficient[i];  
  140.         }  
  141.         return sum;  
  142.     }  
  143.   
  144.     /* 
  145.      * This method is used to calculate each elements of augmented matrix. 
  146.      */  
  147.     private void compute() {  
  148.         if (x == null || y == null || x.length <= 1 || x.length != y.length  
  149.                 || x.length < n || n < 2) {  
  150.             return;  
  151.         }  
  152.         double[] s = new double[(n - 1) * 2 + 1];  
  153.         for (int i = 0; i < s.length; i++) {  
  154.             for (int j = 0; j < x.length; j++) {  
  155.                 s[i] += Math.pow(x[j], i) * weight[j];  
  156.             }  
  157.         }  
  158.         double[] b = new double[n];  
  159.         for (int i = 0; i < b.length; i++) {  
  160.             for (int j = 0; j < x.length; j++) {  
  161.                 b[i] += Math.pow(x[j], i) * y[j] * weight[j];  
  162.             }  
  163.         }  
  164.         double[][] a = new double[n][n];  
  165.         for (int i = 0; i < n; i++) {  
  166.             for (int j = 0; j < n; j++) {  
  167.                 a[i][j] = s[i + j];  
  168.             }  
  169.         }  
  170.   
  171.         // Now we need to calculate each coefficients of augmented matrix  
  172.         coefficient = calcLinearEquation(a, b);  
  173.     }  
  174.   
  175.     /* 
  176.      * Calculate linear equation. 
  177.      *  
  178.      * The matrix equation is like this: Ax=B 
  179.      *  
  180.      * @param a two-dimensional array 
  181.      *  
  182.      * @param b one-dimensional array 
  183.      *  
  184.      * @return x, one-dimensional array 
  185.      */  
  186.     private double[] calcLinearEquation(double[][] a, double[] b) {  
  187.         if (a == null || b == null || a.length == 0 || a.length != b.length) {  
  188.             return null;  
  189.         }  
  190.         for (double[] x : a) {  
  191.             if (x == null || x.length != a.length)  
  192.                 return null;  
  193.         }  
  194.   
  195.         int len = a.length - 1;  
  196.         double[] result = new double[a.length];  
  197.   
  198.         if (len == 0) {  
  199.             result[0] = b[0] / a[0][0];  
  200.             return result;  
  201.         }  
  202.   
  203.         double[][] aa = new double[len][len];  
  204.         double[] bb = new double[len];  
  205.         int posx = -1, posy = -1;  
  206.         for (int i = 0; i <= len; i++) {  
  207.             for (int j = 0; j <= len; j++)  
  208.                 if (a[i][j] != 0.0d) {  
  209.                     posy = j;  
  210.                     break;  
  211.                 }  
  212.             if (posy != -1) {  
  213.                 posx = i;  
  214.                 break;  
  215.             }  
  216.         }  
  217.         if (posx == -1) {  
  218.             return null;  
  219.         }  
  220.   
  221.         int count = 0;  
  222.         for (int i = 0; i <= len; i++) {  
  223.             if (i == posx) {  
  224.                 continue;  
  225.             }  
  226.             bb[count] = b[i] * a[posx][posy] - b[posx] * a[i][posy];  
  227.             int count2 = 0;  
  228.             for (int j = 0; j <= len; j++) {  
  229.                 if (j == posy) {  
  230.                     continue;  
  231.                 }  
  232.                 aa[count][count2] = a[i][j] * a[posx][posy] - a[posx][j]  
  233.                         * a[i][posy];  
  234.                 count2++;  
  235.             }  
  236.             count++;  
  237.         }  
  238.   
  239.         // Calculate sub linear equation  
  240.         double[] result2 = calcLinearEquation(aa, bb);  
  241.   
  242.         // After sub linear calculation, calculate the current coefficient  
  243.         double sum = b[posx];  
  244.         count = 0;  
  245.         for (int i = 0; i <= len; i++) {  
  246.             if (i == posy) {  
  247.                 continue;  
  248.             }  
  249.             sum -= a[posx][i] * result2[count];  
  250.             result[i] = result2[count];  
  251.             count++;  
  252.         }  
  253.         result[posy] = sum / a[posx][posy];  
  254.         return result;  
  255.     }  
  256.   
  257.     public static void main(String[] args) {  
  258.         LeastSquareMethod eastSquareMethod = new LeastSquareMethod(  
  259.                 new double[] { 0.5, 1.0, 1.5, 2.0, 2.5, 3.0 }, new double[] {  
  260.                         1.75, 2.45, 3.81, 4.8, 7.0, 8.6 }, 3);  
  261.         /*double[] coefficients = eastSquareMethod.getCoefficient(); 
  262.         for (double c : coefficients) { 
  263.             System.out.println(c); 
  264.         }*/  
  265.   
  266.         System.out.println(eastSquareMethod.fit(4));  
  267.   
  268.         LeastSquareMethod eastSquareMethod2 = new LeastSquareMethod(  
  269.                 new double[] { 0.5, 1.0, 1.5, 2.0, 2.5, 3.0 }, new double[] {  
  270.                         1.75, 2.45, 3.81, 4.8, 7.0, 8.6 }, 2);  
  271.         System.out.println(eastSquareMethod2.solve(100));  
  272.   
  273.     }  
  274. }  


使用开源库

也可使用Apache开源库commons math,提供的功能更强大,

http://commons.apache.org/proper/commons-math/userguide/fitting.html

 

[html]  view plain  copy
 
  1. <dependency>  
  2.             <groupId>org.apache.commons</groupId>  
  3.             <artifactId>commons-math3</artifactId>  
  4.             <version>3.5</version>  
  5.         </dependency>  



 

 

[java]  view plain  copy
 
    1. private static void testLeastSquareMethodFromApache() {  
    2.         final WeightedObservedPoints obs = new WeightedObservedPoints();  
    3.         obs.add(-3, 4);  
    4.         obs.add(-2, 2);  
    5.         obs.add(-1, 3);  
    6.         obs.add(0, 0);  
    7.         obs.add(1, -1);  
    8.         obs.add(2, -2);  
    9.         obs.add(3, -5);  
    10.   
    11.         // Instantiate a third-degree polynomial fitter.  
    12.         final PolynomialCurveFitter fitter = PolynomialCurveFitter.create(3);  
    13.   
    14.         // Retrieve fitted parameters (coefficients of the polynomial function).  
    15.         final double[] coeff = fitter.fit(obs.toList());  
    16.         for (double c : coeff) {  
    17.             System.out.println(c);  
    18.         }  
    19.     }  

 

转载于:https://www.cnblogs.com/davidwang456/articles/5582752.html

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值