上一篇写了最小二乘回归树(Least Squares Regression Tree),本篇进阶一点,写一下 提升树(Boosting Decision Tree),详细解释可参考 https://www.jianshu.com/p/005a4e6ac775 的第二部分。
还是先上代码,先借用上一篇的代码,生成节点代码中,索引不再需要了,因为提升树只需要一层二叉树。
/**
* 依据给定的X和Y数据,基于最小二乘回归树生成 1 个二叉树(1个节点)
* 选择最优的切分点
* @param xdata
* @param ydata
* @return
*/
BinaryTreeNode generateRegressTreeNode(double[] xdata, double[] ydata) {
BinaryTreeNode brn = null;
int dataLength = xdata.length;
double minSum = 0;
// 遍历输入值,将xdata分为2个部分
for (int i = 0; i < dataLength; i++) {
// X数据的每一个值都可作为切分点
double splitPoint = xdata[i];
int[] r1Idx = new int[dataLength];
int[] r2Idx = new int[dataLength];
for (int j = 0; j < dataLength; j++) {
r1Idx[j] = -1;
r2Idx[j] = -1;
if (xdata[j] > splitPoint) {
r2Idx[j] = j;
} else {
r1Idx[j] = j;
}
}
// 切分点左侧Y的数据均值
double c1 = meanMatrix1DByIdx(ydata, r1Idx);
// 切分点右侧Y的数据均值
double c2 = meanMatrix1DByIdx(ydata, r2Idx);
// 左侧和右侧值的差平方和
double sumsl = sumSquareLoss(ydata, c1, r1Idx, c2, r2Idx);
// 找最小的和(冒泡方式)
if (i == 0 || minSum > sumsl) {
minSum = sumsl;
brn = new BinaryTreeNode(splitPoint, c1, c2);
//brn.setLeftIdx(r1Idx);
//brn.setRightIdx(r2Idx); // 索引不再需要了,因为只需要一层
}
}
return brn;
}
它的训练和预测的逻辑就简单一些了
List<BinaryTreeNode> binaryTreeNodeList; // 所有节点的列表
public BoostingDecisionTree(){
binaryTreeNodeList = new ArrayList<BinaryTreeNode>();
}
/**
* 回归的提升决策树算法
* 依据提升树算法进行预测计算
* @param newXData
* @return
*/
double[] predict(double[] newXData) {
double[] ret = new double[newXData.length];
for (int i = 0; i < newXData.length; i++) {
ret[i] = 0;
for (int j = 0; j < binaryTreeNodeList.size(); j++) {
BinaryTreeNode btn = binaryTreeNodeList.get(j);
if (i == 0) {
System.out.println("BTN:" + btn);
}
if (newXData[i] > btn.getSplitPoint()) {
ret[i] = ret[i] + btn.getRightValue();
} else {
ret[i] = ret[i] + btn.getLeftValue();
}
}
}
return ret;
}
/**
* 基于回归树,根据深度,得到多级的二叉树
* @param xdata
* @param ydata
* @param level
* @return
*/
double[] train(double[] xdata, double[] ydata, int level) {
double[] ret = null;
for (int i = 0; i < level; i++) {
if (i == 0) {
ret = ydata;
}
// 每次的输出值(ydata,就是generateRegressTreeNode的第二个参数)都是上一次的输出值与本次输出值的差
BinaryTreeNode btn = generateRegressTreeNode(xdata, ret);
binaryTreeNodeList.add(btn);
ret = calculateByNode(xdata, ret, btn);
}
return ret;
}
// 计算2次输出值的差
double[] calculateByNode(double[] xdata, double[] ydata, BinaryTreeNode binaryTreeNode) {
double[] ret = new double[xdata.length];
for (int i = 0; i < xdata.length; i++) {
if (xdata[i] > binaryTreeNode.getSplitPoint()) {
ret[i] = ydata[i] - binaryTreeNode.getRightValue();
} else {
ret[i] = ydata[i] - binaryTreeNode.getLeftValue();
}
}
return ret;
}
最后该验证了
// 训练数据同上一篇
double[] xdata = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
double[] ydata = {5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05};
BoostingDecisionTree bdt = new BoostingDecisionTree();
bdt.train(xdata, ydata, 6);
// 预测输入也同上一篇
double[] nxdata = {8.6, 5.5 , 4.4, 3.4, 2.4};
double[] predictRT = bdt.predict(nxdata);
System.out.println("nxdata:" + Arrays.toString(nxdata));
System.out.println("预测结果predictRT :" + Arrays.toString(predictRT));
------------------------------------------------------------
最终结果:
nxdata: 8.6 5.5 4.4 3.4 2.4
预测结果predictRT: 8.9502 6.8197 6.8197 6.5516 5.8183
可以与上一篇对比结果。
总结,提升树利用回归树的方法只分 1 层树。但是会生成多级二叉树,每个树的输入值(xdata)是一样的,但是对应的输出值(ydata)是不一样的,每次的输出值(ydata)都是上一次的输出值与本次输出值的差。f0 选取 0,因此第一次的输出值就使用的是最初的 ydata,预测的时候,判断新的输入数据依次落入每级树的区间,然后将每级树对应的输出值累加。