Java Programmers Learn Algorithms (4): Gradient Descent

   Gradient descent is a common method for linear regression. For a detailed introduction to linear regression and gradient descent, see: https://blog.csdn.net/xiazdong/article/details/7950084 . The math there is fairly heavy and the derivation is hard to follow in full, but the final formula (at the end of that article) is understandable enough to turn into code. One important concept is the cost function; here it is least squares, which measures how far the values computed by the linear formula deviate from the actual values.

       Linear regression often involves multiple variables, e.g. a formula y = a1 * x1 + a2 * x2 + ...... + b. Here y and each x are known, time-based data: every group of x values has a corresponding y. The coefficients a and b are unknown. Linear regression starts from arbitrary values of a and b, substitutes each group of x into the formula to compute a predicted value yy, and then compares yy with the actual y, aiming to make the deviation small. Each new choice of a and b yields a new yy to compare against y. Repeating this search, the a and b that produce the smallest deviation from the actual y become the final formula, which can then compute y for new x values, i.e. make predictions.

       This search for a minimum, with least squares as the cost function, becomes a partial-derivative problem: differentiating reduces the squares in the cost to first order. The result is the gradient descent update formula given at the end of the linked post.
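To make the update rule concrete before looking at the full code, here is a minimal, self-contained sketch (all names are illustrative, not from the article's code) of batch gradient descent for the single-variable case y = a * x + b:

```java
// Minimal sketch of batch gradient descent for y = a * x + b.
public class SimpleGradientDescent {
    public static double[] fit(double[] x, double[] y, double learningRate, int iterations) {
        double a = 0.0, b = 0.0;
        int n = x.length;
        for (int iter = 0; iter < iterations; iter++) {
            double gradA = 0.0, gradB = 0.0;
            for (int i = 0; i < n; i++) {
                double error = (a * x[i] + b) - y[i]; // predicted minus actual
                gradA += error * x[i];
                gradB += error;
            }
            // step against the gradient, averaged over the sample count
            a -= learningRate * gradA / n;
            b -= learningRate * gradB / n;
        }
        return new double[]{a, b};
    }

    public static void main(String[] args) {
        // Data generated from y = 2x + 1; the fit should recover a ≈ 2, b ≈ 1.
        double[] x = {0, 1, 2, 3, 4};
        double[] y = {1, 3, 5, 7, 9};
        double[] ab = fit(x, y, 0.05, 5000);
        System.out.println("a=" + ab[0] + ", b=" + ab[1]);
    }
}
```

The same "predicted minus actual" gradient appears later in the article's `sum` method.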

 This article turns that math into Java code. First, create an interface for the formula:

Formula interface
public interface LinearFormula {
    // set the parameter values, i.e. the a coefficients in the formula
    void setParameters(double[] parameters);
    // calculate the result from x
    double[] calculateResult(double[]... x);
}
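Before the full base class shown later, a hypothetical minimal implementation of this interface for a one-variable line y = a0 + a1 * x may help clarify the contract (the interface is repeated in the sketch so it compiles standalone):

```java
// The interface from above, repeated here so this sketch is self-contained.
interface LinearFormula {
    void setParameters(double[] parameters);
    double[] calculateResult(double[]... x);
}

// Hypothetical implementation for y = a0 + a1 * x.
public class SimpleLine implements LinearFormula {
    private double[] parameters; // {a0, a1}

    @Override
    public void setParameters(double[] parameters) {
        this.parameters = parameters;
    }

    @Override
    public double[] calculateResult(double[]... x) {
        // x[0] is the single variable; one output value per sample
        double[] ret = new double[x[0].length];
        for (int i = 0; i < x[0].length; i++) {
            ret[i] = parameters[0] + parameters[1] * x[0][i];
        }
        return ret;
    }

    public static void main(String[] args) {
        SimpleLine line = new SimpleLine();
        line.setParameters(new double[]{1.0, 2.0}); // y = 1 + 2x
        double[] y = line.calculateResult(new double[]{0, 1, 2});
        System.out.println(y[0] + ", " + y[1] + ", " + y[2]); // 1.0, 3.0, 5.0
    }
}
```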

Then the gradient descent implementation:

public class GradientDescent {
    private static final Logger logger = LoggerFactory.getLogger(GradientDescent.class);
 
    private double learningRate;
    private LinearFormula linearFormula;
    // whether a bias (offset) is needed, i.e. b in the formula
    private boolean needBias = true;
 
    public GradientDescent(double learningRate, LinearFormula linearFormula, boolean needBias){
        this.learningRate = learningRate;
        this.linearFormula = linearFormula;
        this.needBias = needBias;
    }
     
    private double sum(double[] y, double[] p, double[] x){
        boolean isXNotEmpty = ArrayUtils.isNotEmpty(x);
        double sum = 0;
        for (int i = 0; i < y.length; i++){
            double v = p[i] - y[i];
            if (isXNotEmpty) {
                v = v * x[i];
            }
            sum = sum + v;
        }
        return sum;
    }
     
    private double[] generateArray(int length, double pointValue){
        double[] ret = new double[length];
        for (int i = 0; i < length; i++){
            ret[i] = pointValue;
        }
        return ret;
    }
     
    /**
     * Training method; call this to train on the data.
     * @param y the actual values
     * @param expectMape once gradient descent reaches its minimum deviation, further training can
     *        make the deviation grow again; within a fixed number of iterations it may shrink and
     *        then increase. This value is the target MAPE: when it is reached, training stops
     *        early because the result already meets expectations.
     * @param trainCount the maximum number of training iterations
     * @param x the input variables
     * @return the parameters of the formula, i.e. the a values and b
     */
    public double[] train(double[] y, double expectMape, int trainCount, double[]... x){
        Map<Double, Double[]> map = new HashMap<Double, Double[]>();
        Map<Double, Double> mapeMap = new HashMap<Double, Double>();
         
        Map<Double, Double[]> forecastMap = new HashMap<Double, Double[]>();
 
        double[] lsa = new double[trainCount];
        int length = y.length;
        if (length == 0) {
            logger.error("wrong length: " + length);
        }
        int xlength = x.length;
        int alength = xlength + 1; // 参数a的数量是变量个数+1
        logger.info("length: " + length + ", xlength: " + xlength + ", alength: " + alength);
         
        double[] a = generateArray(alength, 1.0d);
 
        double minMape = 1.0d;
        for (int i = 0 ; i < trainCount; i++){
            linearFormula.setParameters(a); // each iteration assigns new parameters, computes the result, then compares
            double[] p = linearFormula.calculateResult(x);
            Double[] ao = new Double[alength];
            for (int j = 0; j < alength; j++){
                ao[j] = a[j];
            }
             
            if (needBias) {
                a[0] = a[0] - (learningRate / length) * sum(y, p, null); // compute a0, the parameter without a variable, e.g. y = a0 + a1 * x1
            }
            for (int j = 0; j < xlength; j++){
                a[j + 1] = a[j + 1] - (learningRate / length) * sum(y, p, x[j]);
            }
            lsa[i] = calculateLeastSquare(y, p);
            map.put(lsa[i], ao);
             
            Double[] po = new Double[p.length];
            int iii = 0;
            for (double pv : p) {
                po[iii++] = pv;
            }
            forecastMap.put(lsa[i], po);
             
            if (expectMape >= 0) {
                double mape = getMAPE(y, p);
                if (minMape > mape) {
                    minMape = mape;
                }
                mapeMap.put(lsa[i], mape);
                if (expectMape >= mape) { // expectation reached, stop training
                    logger.info("reach the expected result, end.");
                    break;
                }
            }
        }
        double min = getMinValueFrom1D(lsa);
        logger.info("min: " + min + ", mape: " + mapeMap.get(min));
         
        double[] p = new double[y.length];
        int iii = 0;
        for (Double pv : forecastMap.get(min)) {
            p[iii++] = pv.doubleValue();
        }
 
        Double[] finalA = map.get(min);
        double[] ret = new double[finalA.length];
        int idx = 0;
        for (Double fa : finalA){
            ret[idx] = 0.0;
            if (fa != null){
                ret[idx] = fa.doubleValue();
            }
            idx++;
        }
        logger.info("ret: " + ret[0] + ", ret1: " + ret[1] + ", minMape: " + minMape + ", expectMape: " + expectMape);
        return ret;
    }
     
    /**
     * Utility method: perturb the input da by the offset value, to generate test data.
     * @param da the input array
     * @param offset the perturbation factor
     * @return the perturbed copy of da
     */
    public static double[] confusion(double[] da, double offset){
        double[] ret = new double[da.length];
        for (int i = 0; i < da.length; i++){
            double random = Math.random();
            double flag = ((int)(random * 10)) % 2 == 0 ? 1 : -1;
            double tmp = da[i] * offset;
            if (random >= da[i] * offset){
                ret[i] = da[i] + flag * tmp;
            } else {
                ret[i] = da[i] + (da[i] * flag * random);
            }
        }
         
        return ret;
    }
    // compute the least-squares cost
    public static double calculateLeastSquare(double[] y, double[] p){
       double ret = 0;
       for (int i = 0; i < y.length; i++){
           double lack = y[i] - p[i];
           ret += (lack * lack);
       }
        
       ret = ret / (2 * y.length);
 
       return ret;
   }
    
   public static double getMAPE(double[] source, double[] predict){
      Validate.isTrue(source.length > 0,
            "The array length of the parameters source and predict should be more than 0.");
      Validate.isTrue(source.length == predict.length,
            "The array length of the parameters source and predict should be same. source.length: " + source.length +
            ", predict.length: " + predict.length);
      double ret = 0.0;
       
      for (int i = 0; i < source.length; i++){
         double base = source[i];
         if (base == 0){
            base = 1;
         }
         ret = ret + Math.abs((source[i] - predict[i]) / base);
      }
      ret = ret / source.length;
      return ret;
   }
    
   public static double getTopDifference(double[] source, double[] predict, int topNumber){
       Validate.isTrue(topNumber > 0 && topNumber < 100,
             "The parameter topNumber should be more than 0 and less than 100.");
       
      Validate.isTrue(source.length > 0,
            "The array length of the parameters source and predict should be more than 0.");
      Validate.isTrue(source.length == predict.length,
            "The array length of the parameters source and predict should be same.");
      double ret = 0.0;
      double[] difference = new double[source.length];
      for (int i = 0; i < source.length; i++){
         double base = source[i];
         if (base == 0){
            base = 1;
         }
         difference[i] = Math.abs((source[i] - predict[i]) / base);
      }
       
       Arrays.sort(difference); // sort ascending
       
      double percent = ((double) topNumber) / 100.0d;
       
      int idx = (int)(percent * (source.length - 1));
      ret = difference[idx];
      return ret;
   }
   public static double getMinValueFrom1D(double[] xa){
       double ret = xa[0];
       for (double d : xa){
           if (d < ret){
               ret = d;
           }
       }
       return ret;
   }
}
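As a sanity check on the two error metrics above, here is a standalone sketch (class and method names are illustrative) reproducing the MAPE definition (mean of |actual - predicted| / actual, with zero actuals replaced by 1) and the percentile lookup used by getTopDifference:

```java
public class MetricsCheck {
    // Same MAPE definition as GradientDescent.getMAPE, repeated standalone.
    public static double mape(double[] source, double[] predict) {
        double sum = 0.0;
        for (int i = 0; i < source.length; i++) {
            double base = source[i] == 0 ? 1 : source[i];
            sum += Math.abs((source[i] - predict[i]) / base);
        }
        return sum / source.length;
    }

    // The percentile lookup from getTopDifference, on an already-sorted array.
    public static double topDifference(double[] diffsSorted, int topNumber) {
        int idx = (int) ((topNumber / 100.0) * (diffsSorted.length - 1));
        return diffsSorted[idx];
    }

    public static void main(String[] args) {
        // |100-90|/100 = 0.1 and |200-220|/200 = 0.1, so MAPE = 0.1
        System.out.println(mape(new double[]{100, 200}, new double[]{90, 220}));
        // 80% of (5-1) = 3.2, truncated to index 3 -> 0.4
        System.out.println(topDifference(new double[]{0.1, 0.2, 0.3, 0.4, 0.5}, 80));
    }
}
```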

Now we can run a test. The example uses the Gaussian distribution (GaussianDistribution); its implementation follows.

Gaussian distribution implementation
public abstract class BaseLinearFormula implements LinearFormula{
 
    private double[] parameters;
 
    public double[] getParameters() {
        return parameters;
    }
 
    @Override
    public void setParameters(double[] parameters) {
        validateParameters(parameters);
        this.parameters = parameters;
    }
     
    @Override
    public double[] calculateResult(double[]... x) {
         
        validateInputData(x);
 
        double[] ret = new double[x[0].length];
 
        for (int i = 0; i < x[0].length; i++){
            double[] inputX = new double[x.length];
            for (int j = 0; j < x.length; j++){
                inputX[j] = x[j][i];
            }
            ret[i] = retrieveValue(parameters, inputX);
        }
         
        return ret;
         
    }
     
    protected abstract double retrieveValue(double[] parameters, double[] data);
  
    protected abstract void validateInputData(double[]... x) throws IllegalArgumentException;
     
    protected abstract void validateParameters(double[] parameters) throws IllegalArgumentException;
     
}
 
 
/**
 * Gaussian distribution implementation, used for testing.
 */
public class GaussianDistribution extends BaseLinearFormula{
    
   public static double[] generateGaussianDistributioSampleData(int n){
 
      int times = 12;
      double[] x = new double[n * times];
      double tmp = 0d;
      for (int i = 0; i < n * times; i++){
         tmp = tmp + (1.0 / times);
         x[i] = tmp - n / 2;
      }
       
      return x;
   }
    
    @Override
    protected double retrieveValue(double[] parameters, double[] data) {
        double ret = 0;
        double exponent = -1 * (Math.pow((data[0] - parameters[0]), 2) / (2 * Math.pow(parameters[1], 2)));
        ret = (1 / (parameters[1] * Math.sqrt(2 * Math.PI))) * Math.exp(exponent);
         
        return ret;
    }
 
    @Override
    protected void validateInputData(double[]... x) throws IllegalArgumentException {
        int variableNumber = x.length;
        Validate.isTrue(variableNumber == 1, "Gaussian Distribution should have 1 variable");
    }
 
    @Override
    protected void validateParameters(double[] parameters) throws IllegalArgumentException {
        Validate.isTrue(parameters.length == 2, "Gaussian Distribution should have 2 parameters");
    }
}
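The density formula in retrieveValue can be checked by hand. Below is a standalone version (the class name is illustrative); at the mean with sigma = 1, the density should be 1/sqrt(2*pi) ≈ 0.3989:

```java
public class GaussianCheck {
    // Standalone version of the density computed in GaussianDistribution.retrieveValue:
    // f(x) = 1 / (sigma * sqrt(2*pi)) * exp(-(x - mu)^2 / (2 * sigma^2))
    public static double pdf(double x, double mu, double sigma) {
        double exponent = -Math.pow(x - mu, 2) / (2 * sigma * sigma);
        return Math.exp(exponent) / (sigma * Math.sqrt(2 * Math.PI));
    }

    public static void main(String[] args) {
        // At x = mu = 0 with sigma = 1: exp(0) / sqrt(2*pi) ≈ 0.3989
        System.out.println(pdf(0, 0, 1));
    }
}
```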

 And finally the test:

// The Gaussian distribution example isn't a great fit, but it illustrates the usage
GaussianDistribution gaussianDistribution = new GaussianDistribution();
GradientDescent gradientDescent = new GradientDescent(0.1, gaussianDistribution, false);
 
// generate test data and parameters
double[] xtest = GaussianDistribution.generateGaussianDistributioSampleData(6);
double[] a = {0, 0.6};
 
gaussianDistribution.setParameters(a);
 
// calculate results with the formula
double[] sy = gaussianDistribution.calculateResult(xtest);
// perturb the computed values, then compare the deviation; the perturbed data is used for the subsequent test
double[] cy = GradientDescent.confusion(sy, 0.2);
System.out.println("confusion MAPE: " + GradientDescent.getMAPE(cy, sy) +
         ", confusion top 90: " + GradientDescent.getTopDifference(cy, sy, 90) +
         ", confusion top 80: " + GradientDescent.getTopDifference(cy, sy, 80));
 
System.out.println("a: " + a[0] + ", " + a[1]);
// train on the perturbed values against the original input; the result is the formula's parameters
double[] atest = gradientDescent.train(cy, 0.0, 6000, xtest);
System.out.println("atest: " + atest[0] + ", " + atest[1]);
gaussianDistribution.setParameters(atest); // set the trained parameters
double[] p = gaussianDistribution.calculateResult(xtest); // calculate on the original data
System.out.println("final predict GaussianDistribution: " + Arrays.toString(p));
// print the deviations
System.out.println("final MAPE: " + GradientDescent.getMAPE(cy, p) +
         ", final top 90: " + GradientDescent.getTopDifference(cy, p, 90) +
         ", final top 80: " + GradientDescent.getTopDifference(cy, p, 80));
 
 
-------- The results are poor; for reference only --------
Output
confusion MAPE: 0.19411157584986596, confusion top 90: 0.25000000000000006, confusion top 80: 0.25
a: 0.0, 0.6
 
 
The deviation is fairly large:
final predict GaussianDistribution: 0.0058, 0.0069, 0.0081, 0.0094, 0.011, 0.0128, 0.0148, 0.017, 0.0196, 0.0224, 0.0256, 0.0291, 0.0329, 0.0372, 0.0418, 0.0469, 0.0524, 0.0583, 0.0647, 0.0715, 0.0788, 0.0865, 0.0946, 0.1031, 0.112, 0.1212, 0.1307, 0.1405, 0.1505, 0.1606, 0.1708, 0.181, 0.1911, 0.2011, 0.2109, 0.2204, 0.2295, 0.2381, 0.2462, 0.2537, 0.2605, 0.2665, 0.2717, 0.276, 0.2794, 0.2819, 0.2834, 0.2839, 0.2834, 0.2819, 0.2794, 0.276, 0.2717, 0.2665, 0.2605, 0.2537, 0.2462, 0.2381, 0.2295, 0.2204, 0.2109, 0.2011, 0.1911, 0.181, 0.1708, 0.1606, 0.1505, 0.1405, 0.1307, 0.1212, 0.112, 0.1031,
final MAPE: 1429.2300451464919, final top 90: 988.6810056795385, final top 80: 170.0713463977803

 
