本篇继续进阶一点,写一下 梯度提升决策树(Gradient Boosting Decision Tree),详细解释可参考 GBDT:梯度提升决策树 - 简书 的第三部分
还是先上代码,梯度提升决策树是能够支持多种损失函数的,关于 损失函数的定义,老规矩,自己搜。既然要支持多种损失函数,因此先写个接口类,然后再来个实现,后面会用到
损失函数接口类
public interface LossFunction {
public double loss(double y, double predict);
/**
* 负梯度
* 详见:《The Elements of Statistical Learning .pdf》 的第10章的表 TABLE 10.2 Gradients for commonly used loss functions.
* @param y 实际值
* @param predict 预测值
* @return
*/
public double negativePartialDerivative(double y, double predict);
}
损失函数的实现 - 差平方
public class SquareErrorLoss implements LossFunction{
@Override
public double loss(double y, double predict) {
double loss = y - predict;
return loss * loss;
}
/**
* 负梯度
* 详见:《The Elements of Statistical Learning .pdf》 的第10章的表 TABLE 10.2 Gradients for commonly used loss functions.
*/
@Override
public double negativePartialDerivative(double y, double predict) {
return y - predict;
}
}
生成树的节点的代码借用上一篇提升树的代码,计算损失函数之和部分调整了一下。
/**
* 依据给定的X和Y数据,基于最小二乘回归树生成 1 个二叉树(1个节点)
* 选择最优的切分点
* @param xdata
* @param ydata
* @return
*/
BinaryTreeNode generateRegressTreeNode(double[] xdata, double[] ydata) {
BinaryTreeNode brn = null;
int dataLength = xdata.length;
double minSum = 0;
// 遍历输入值,将xdata分为2个部分
for (int i = 0; i < dataLength; i++) {
// X数据的每一个值都可作为切分点
double splitPoint = xdata[i];
int[] r1Idx = new int[dataLength];
int[] r2Idx = new int[dataLength];
for (int j = 0; j < dataLength; j++) {
r1Idx[j] = -1;
r2Idx[j] = -1;
if (xdata[j] > splitPoint) {
r2Idx[j] = j;
} else {
r1Idx[j] = j;
}
}
// 切分点左侧Y的数据均值
double c1 = meanMatrix1DByIdx(ydata, r1Idx);
// 切分点右侧Y的数据均值
double c2 = meanMatrix1DByIdx(ydata, r2Idx);
// 左侧和右侧值的损失函数之和
double sumsl = sumLoss(ydata, c1, r1Idx, c2, r2Idx); // 更改的地方
// 找最小的和(冒泡方式)
if (i == 0 || minSum > sumsl) {
minSum = sumsl;
brn = new BinaryTreeNode(splitPoint, c1, c2);
//brn.setLeftIdx(r1Idx);
//brn.setRightIdx(r2Idx); // 索引不再需要了,因为只需要一层
}
}
return brn;
}
/**
* 偏差值使用损失函数实现
*/
double sumLoss(double[] data, double c1, int[] r1IndexArray, double c2, int[] r2IndexArray) {
double sum = 0.0d;
for (int idx : r1IndexArray) {
if (idx > -1) {
sum = sum + lossFunction.loss(data[idx], c1);
}
}
for (int idx : r2IndexArray) {
if (idx > -1) {
sum = sum + lossFunction.loss(data[idx], c2);
}
}
return sum;
}
它的训练和预测的逻辑比提升树复杂一点
LossFunction lossFunction; // 损失函数
List<BinaryTreeNode> binaryTreeNodeList;
double f0; // 初始值
public GradientBoostingDecisionTree(LossFunction lossFunction){
this.lossFunction = lossFunction;
binaryTreeNodeList = new ArrayList<BinaryTreeNode>();
}
/**
* 初始化时,估计使损失函数极小化的常数值,它是只有一个根节点的树,即gama是一个常数值。
* 在此取均值
* @param data
* @return
*/
private double retrieveArgmin(double[] data){
double total = 0.0d;
for (int i = 0; i < data.length; i++){
total = total + data[i];
}
return total / data.length;
}
/**
* 依据梯度提升树算法进行预测计算
* @param newXData
* @return
*/
double[] predict(double[] newXData) {
double[] ret = new double[newXData.length];
for (int i = 0; i < newXData.length; i++) {
ret[i] = f0;
for (int j = 0; j < binaryTreeNodeList.size(); j++) {
BinaryTreeNode btn = binaryTreeNodeList.get(j);
if (newXData[i] > btn.getSplitPoint()) {
ret[i] = ret[i] + btn.getRightValue();
} else {
ret[i] = ret[i] + btn.getLeftValue();
}
}
}
return ret;
}
/**
* 基于回归树,根据深度,得到多级的二叉树
* @param xdata
* @param ydata
* @param level
* @return
*/
void train(double[] xdata, double[] ydata, int level) {
f0 = retrieveArgmin(ydata);
double[] temp = null;
for (int i = 0; i < level; i++) {
if (i == 0) {
// 计算第一次差值
temp = calculateFirstResidual(ydata, f0);
}
BinaryTreeNode btn = generateRegressTreeNode(xdata, temp);
binaryTreeNodeList.add(btn);
temp = calculateByNode(xdata, temp, btn);
}
}
// 计算每一个Y与对比值fm的负梯度值
double[] calculateFirstResidual(double[] ydata, double fm) {
double[] ret = new double[ydata.length];
for (int i = 0; i < ydata.length; i++) {
ret[i] = lossFunction.negativePartialDerivative(ydata[i], fm);
}
}
// 计算每一个Y与切分点两侧值的负梯度值
double[] calculateByNode(double[] xdata, double[] ydata, BinaryTreeNode binaryTreeNode) {
double[] ret = new double[xdata.length];
for (int i = 0; i < xdata.length; i++) {
if (xdata[i] > binaryTreeNode.getSplitPoint()) {
ret[i] = lossFunction.negativePartialDerivative(ydata[i], binaryTreeNode.getRightValue());
} else {
ret[i] = lossFunction.negativePartialDerivative(ydata[i], binaryTreeNode.getLeftValue());
}
}
return ret;
}
最后该验证了
double[] xdata = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
double[] ydata = {5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05};
GradientBoostingDecisionTree gbdt = new GradientBoostingDecisionTree(new SquareErrorLoss());
gbdt.train(xdata, ydata, 6);
double[] nxdata = {8.6, 5.5, 4.4, 3.4, 2.4};
double[] predictRT = gbdt.predict(nxdata);
System.out.println("nxdata:" + Arrays.toString(nxdata));
System.out.println("预测结果predict SquareErrorLoss :" + Arrays.toString(predictRT));
------------------------------------------------------------
最终结果:
nxdata: 8.6 5.5 4.4 3.4 2.4
预测结果SquareErrorLoss: 8.9502 6.8197 6.8197 6.5516 5.8183
总结,梯度提升树(解决回归问题):上述2个都是基于最小二乘的,但是对于其它的 损失函数 就不适用了。而梯度提升就是通过近似的方式,能够支持多种 损失函数,要求损失函数能够做1阶偏导数。实现流程与提升树类似。f0 选用的是ydata的均值, 因此第一次的输出值就使用的是 ydata的每一个值和f0,通过损失函数的偏导数方法计算的结果,预测操作与提升树类似,差别是 f0 的选取。