Java程序员学算法(2) - 提升树算法(Boosting Decision Tree)

       上一篇写了最小二乘回归树(Least Squares Regression Tree),本篇进阶一点,写一下 提升树(Boosting Decision Tree),详细解释可参考 https://www.jianshu.com/p/005a4e6ac775 的第二部分。

 还是先上代码,先借用上一篇的代码,生成节点代码中,索引不再需要了,因为提升树只需要一层二叉树。

/**
 * 依据给定的X和Y数据,基于最小二乘回归树生成 1 个二叉树(1个节点)
 * 选择最优的切分点
 * @param xdata
 * @param ydata
 * @return
 */
BinaryTreeNode generateRegressTreeNode(double[] xdata, double[] ydata) {
    BinaryTreeNode brn = null;
    int dataLength = xdata.length;
    double minSum = 0;
    // 遍历输入值,将xdata分为2个部分
    for (int i = 0; i < dataLength; i++) {
        // X数据的每一个值都可作为切分点
        double splitPoint = xdata[i];
            
        int[] r1Idx = new int[dataLength];
        int[] r2Idx = new int[dataLength];
            
        for (int j = 0; j < dataLength; j++) {
            r1Idx[j] = -1;
            r2Idx[j] = -1;
            if (xdata[j] > splitPoint) {
                r2Idx[j] = j;
            } else {
                r1Idx[j] = j;
            }
        }
        // 切分点左侧Y的数据均值
        double c1 = meanMatrix1DByIdx(ydata, r1Idx);
        // 切分点右侧Y的数据均值
        double c2 = meanMatrix1DByIdx(ydata, r2Idx);
            
        // 左侧和右侧值的差平方和
        double sumsl = sumSquareLoss(ydata, c1, r1Idx, c2, r2Idx);
            
        // 找最小的和(冒泡方式)
        if (i == 0 ||  minSum > sumsl) {
            minSum = sumsl;
            brn = new BinaryTreeNode(splitPoint, c1, c2);
            //brn.setLeftIdx(r1Idx);
            //brn.setRightIdx(r2Idx); // 索引不再需要了,因为只需要一层
        }
    }
    return brn;
}

它的训练和预测的逻辑就简单一些了

List<BinaryTreeNode> binaryTreeNodeList; // 所有节点的列表

public BoostingDecisionTree(){
   binaryTreeNodeList = new ArrayList<BinaryTreeNode>();
}
    
   
/**
  * 回归的提升决策树算法
  * 依据提升树算法进行预测计算
  * @param newXData
  * @return
  */
double[] predict(double[] newXData) {
    double[] ret = new double[newXData.length];
    for (int i = 0; i < newXData.length; i++) {
        ret[i] = 0;
        for (int j = 0; j < binaryTreeNodeList.size(); j++) {
            BinaryTreeNode btn = binaryTreeNodeList.get(j);
            if (i == 0) {
                System.out.println("BTN:" + btn);
            }
            if (newXData[i] > btn.getSplitPoint()) {
                ret[i] = ret[i] + btn.getRightValue();
            } else {
                ret[i] = ret[i] + btn.getLeftValue();
            }
     
        }
    }
    return ret;
}
/**
 * 基于回归树,根据深度,得到多级的二叉树
 * @param xdata
 * @param ydata
 * @param level
 * @return
 */
double[] train(double[] xdata, double[] ydata, int level) {
    double[] ret = null;
    for (int i = 0; i < level; i++) {
        if (i == 0) {
            ret = ydata;
        }
        // 每次的输出值(ydata,就是generateRegressTreeNode的第二个参数)都是上一次的输出值与本次输出值的差
        BinaryTreeNode btn = generateRegressTreeNode(xdata, ret);
        binaryTreeNodeList.add(btn);
        ret = calculateByNode(xdata, ret, btn);
     }

     return ret;
}
// 计算2次输出值的差
double[] calculateByNode(double[] xdata, double[] ydata, BinaryTreeNode binaryTreeNode) {
    double[] ret = new double[xdata.length];
        
    for (int i = 0; i < xdata.length; i++) {
        if (xdata[i] > binaryTreeNode.getSplitPoint()) {
            ret[i] = ydata[i] - binaryTreeNode.getRightValue();
        } else {
            ret[i] = ydata[i] - binaryTreeNode.getLeftValue();
        }
    }
   
    return ret;
}
    

最后该验证了

// 训练数据同上一篇
double[] xdata = {1,    2,    3,    4,    5,    6,    7,    8,    9,    10};
double[] ydata = {5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05};
        
BoostingDecisionTree bdt = new BoostingDecisionTree();

bdt.train(xdata, ydata, 6);

// 预测输入也同上一篇
double[] nxdata = {8.6, 5.5 , 4.4, 3.4, 2.4};
double[] predictRT = bdt.predict(nxdata);

System.out.println("nxdata:" + Arrays.toString(nxdata));
System.out.println("预测结果predictRT :" + Arrays.toString(predictRT));

------------------------------------------------------------
最终结果:
nxdata: 8.6  5.5  4.4  3.4  2.4  

预测结果predictRT: 8.9502  6.8197  6.8197  6.5516  5.8183  

可以与上一篇对比结果。

总结,提升树利用回归树的方法只分 1 层树。但是会生成多级二叉树每个树的输入值(xdata)是一样的,但是对应的输出值(ydata)是不一样的,每次的输出值(ydata)都是上一次的输出值与本次输出值的差。f0 选取 0,因此第一次的输出值就使用的是最初的 ydata,预测的时候,判断新的输入数据依次落入每级树的区间,然后将每级树对应的输出值累加。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值