Java程序员学算法(5) - Adam

        Adam在做线性回归时也比较常用的方法,很多时候比梯度下降好用。它的实现和梯度下降很类似,增加了超参数 beta1 和 beta2 控制了这些移动均值的衰减率,具体介绍可详见:https://www.cnblogs.com/yifdu25/p/8183587.html 。Adam带来的效果是:训练的时候,可以极度接近极限值,而不会再逐渐偏离。梯度下降会有可能发生偏离越来越大的情况。

如下的2个图就是对算法的描述,摘自上面的url。

 

有了这个图,我们就可以写代码了,它的实现与梯度下降类似,可以在梯度下降的代码基础上修改。

public class Adam {
    private static final Logger logger = LoggerFactory.getLogger(Adam.class);
 
    private double learningRate;
    private LinearFormula linearFormula;
    private boolean needBias = true;
     
    private double b1 = 0.9d;  //英文算法描述中的 b1
    private double b2 = 0.999d; //英文算法描述中的 b2
    private double e = Math.pow(10, -8); //英文算法描述中的 e (常数)
     
    public Adam(double learningRate, LinearFormula linearFormula, boolean needBias){
        this.learningRate = learningRate;
        this.linearFormula = linearFormula;
        this.needBias = needBias;
    }
    // 这个方法和上一篇 梯度下降的一致
    private double sum(double[] y, double[] p, double[] x){
        boolean isXNotEmpty = ArrayUtils.isNotEmpty(x);
        double sum = 0;
        for (int i = 0; i < y.length; i++){
            double v = p[i] - y[i];
            if (isXNotEmpty) {
                v = v * x[i];
            }
            sum = sum + v;
        }
        sum = sum / y.length;
        return sum;
    }
     
     // 这个方法也是和梯度下降的一致
    private double[] generateArray(int length, double pointValue){
        double[] ret = new double[length];
        for (int i = 0; i < length; i++){
            ret[i] = pointValue;
        }
        return ret;
    }
     
    public double[] train(double[] y, double expectMape, int trainCount, double[]... x){
        Map<Double, Double[]> map = new HashMap<Double, Double[]>();
        Map<Double, Double> mapeMap = new HashMap<Double, Double>();
         
        Map<Double, Double[]> forecastMap = new HashMap<Double, Double[]>();
 
        double[] lsa = new double[trainCount];
        int length = y.length;
        if (length == 0) {
            logger.error("wrong-------------------------length: " + length);
        }
        int xlength = x.length;
        int alength = xlength + 1; // 参数a的数量是变量个数+1
        logger.info("length: " + length + ", xlength: " + xlength + ", alength: " + alength);
         
        double[] a = generateArray(alength, 0.5d);
        a[0] = 0.0d;
        double[] m = generateArray(alength, 0.0d);
        double[] v = generateArray(alength, 0.0d);
 
//        double alpha = learningRate * Math.sqrt(1 - b2) / (1 - b1);
        double minMape = 1.0d;
        for (int i = 0 ; i < trainCount; i++){
            linearFormula.setParameters(a);
            double[] p = linearFormula.calculateResult(x);
            Double[] ao = new Double[alength];
            for (int j = 0; j < alength; j++){
                ao[j] = a[j];
            }
            if (i % 5000 == 0  || i < 5) {
                MatrixHelper.outputMatrix1(i + " a", a);
            }
            double stepb1 = Math.pow(b1, (double) (i + 1)); // 参考tensorflow
            double stepb2 = Math.pow(b2, (double) (i + 1)); // 参考tensorflow
            // 算法循环
            for (int j = 0; j < alength; j++){
                int xIdx = j - 1;
                if (needBias == false && xIdx < 0) {
                    continue;
                }
                double g = 0.0d;
                if (j == 0) {
                    g = sum(y, p, null); // 计算 a0 的值,也就是没有变量的那个参数。例如:y = a0 + a1 * x1
                } else {
                    g = sum(y, p, x[xIdx]);
                }
                 
                m[j] = b1 * m[j] + (1 - b1) * g;
                v[j] = b2 * v[j] + (1 - b2) * g * g;
                double mtemp = m[j] / (1 - stepb1);
                double vtemp = v[j] / (1 - stepb2);
                a[j] = a[j] - learningRate * mtemp / (Math.sqrt(vtemp) + e);
//                a[j] = a[j] - alpha * m[j] / (Math.sqrt(v[j]) + e);
            }
            lsa[i] = calculateLeastSquare(y, p);
            map.put(lsa[i], ao);
             
            Double[] po = new Double[p.length];
            int iii = 0;
            for (double pv : p) {
                po[iii++] = pv;
            }
            forecastMap.put(lsa[i], po);
             
            if (expectMape >= 0) {
                double mape = getMAPE(y, p);
                if (minMape > mape) {
                    minMape = mape;
                }
                mapeMap.put(lsa[i], mape);
                if (expectMape >= mape) { // 达到预期,不再训练
                    logger.info("reach the expected result, end.");
                    break;
                }
            }
        }
 
        double min = getMinValueFrom1D(lsa);
        logger.info("min: " + min + ", mape: " + mapeMap.get(min));
         
        double[] p = new double[y.length];
        int iii = 0;
        for (Double pv : forecastMap.get(min)) {
            p[iii++] = pv.doubleValue();
        }
 
        Double[] finalA = map.get(min);
        double[] ret = new double[finalA.length];
        int idx = 0;
        for (Double fa : finalA){
            ret[idx] = 0.0;
            if (fa != null){
                ret[idx] = fa.doubleValue();
            }
            idx++;
        }
        logger.info("ret: " + ret[0] + ", ret1: " + ret[1] + ", minMape: " + minMape + ", expectMape: " + expectMape);
 
        return ret;
    }
    public static double getMAPE(double[] source, double[] predict){
      Validate.isTrue(source.length > 0,
            "The array length of the parameters source and predict should be more than 0.");
      Validate.isTrue(source.length == predict.length,
            "The array length of the parameters source and predict should be same. source.length: " + source.length +
            ", predict.length: " + predict.length);
      double ret = 0.0;
        
      for (int i = 0; i < source.length; i++){
         double base = source[i];
         if (base == 0){
            base = 1;
         }
         ret = ret + Math.abs((source[i] - predict[i]) / base);
      }
      ret = ret / source.length;
      return ret;
   }
     
   public static double getTopDifference(double[] source, double[] predict, int topNumber){
      Validate.isTrue(topNumber > 0 && topNumber < 100,
            "The parameter topNumber should be more than 0 and less then 100.");
        
      Validate.isTrue(source.length > 0,
            "The array length of the parameters source and predict should be more than 0.");
      Validate.isTrue(source.length == predict.length,
            "The array length of the parameters source and predict should be same.");
      double ret = 0.0;
      double[] difference = new double[source.length];
      for (int i = 0; i < source.length; i++){
         double base = source[i];
         if (base == 0){
            base = 1;
         }
         difference[i] = Math.abs((source[i] - predict[i]) / base);
      }
        
      Arrays.sort(difference); // order by ascent
        
      double percent = ((double) topNumber) / 100.0d;
        
      int idx = (int)(percent * (source.length - 1));
      ret = difference[idx];
      return ret;
   }
   public static double getMinValueFrom1D(double[] xa){
       double ret = xa[0];
       for (double d : xa){
           if (d < ret){
               ret = d;
           }
       }
       return ret;
   }
   // 计算最小二乘的结果
    public static double calculateLeastSquare(double[] y, double[] p){
       double ret = 0;
       for (int i = 0; i < y.length; i++){
           double lack = y[i] - p[i];
           ret += (lack * lack);
       }
         
       ret = ret / (2 * y.length);
  
       return ret;
   }
}

这个算法的测试方法可以用 梯度下降的。可参考 Java程序员学算法(4) - 梯度下降(Gradient Descent)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值