在算法如此时髦的当前,作为一名工程方面的Java程序员,同时也是算法小白,有着追赶时髦的好奇心,尝试着向算法的大门张望一下,因此,在此写一下在门缝中看到的东西。先从能看懂的CART的最小二乘回归树开始,并使用擅长的Java来实现。
首先,进入名词解释,CART( Classification And Regression Tree)和最小二乘法,它们的具体解释还请自己搜索吧。
最小二乘回归树:简单来说就是基于最小二乘的回归数。详细解释可参考 https://www.jianshu.com/p/005a4e6ac775 的第一部分。
然后进入代码阶段:
1、创建基础的对象,二叉树的节点- BinaryTreeNode ,其有五个属性
// 拆分点
double splitPoint;
// 拆分点左边的数据的均值,包含拆分点
double leftValue; // include the splitValue
// 拆分点左边元素的索引
int[] leftIdx;
// 拆分点左边的数据的均值
double rightValue;
// 拆分点右边元素的索引
int[] rightIdx;
2、开始实现逻辑
创建回归数的节点,如下为参考url的:算法5.5里面的公式5.21
/**
* 依据给定的X和Y数据,基于最小二乘回归树生成 1 个二叉树(1个节点)
* 选择最优的切分点
* @param xdata
* @param ydata
* @return
*/
BinaryTreeNode generateRegressTreeNode(double[] xdata, double[] ydata) {
BinaryTreeNode brn = null;
int dataLength = xdata.length;
double minSum = 0;
// 遍历输入值,将xdata分为2个部分
for (int i = 0; i < dataLength; i++) {
// X数据的每一个值都可作为切分点
double splitPoint = xdata[i];
int[] r1Idx = new int[dataLength];
int[] r2Idx = new int[dataLength];
for (int j = 0; j < dataLength; j++) {
r1Idx[j] = -1;
r2Idx[j] = -1;
if (xdata[j] > splitPoint) {
r2Idx[j] = j;
} else {
r1Idx[j] = j;
}
}
// 切分点左侧Y的数据均值
double c1 = meanMatrix1DByIdx(ydata, r1Idx);
// 切分点右侧Y的数据均值
double c2 = meanMatrix1DByIdx(ydata, r2Idx);
// 左侧和右侧值的差平方和
double sumsl = sumSquareLoss(ydata, c1, r1Idx, c2, r2Idx);
// 找最小的和(冒泡方式)
if (i == 0 || minSum > sumsl) {
minSum = sumsl;
brn = new BinaryTreeNode(splitPoint, c1, c2);
brn.setLeftIdx(r1Idx);
brn.setRightIdx(r2Idx);
}
}
return brn;
}
// 左侧和右侧值的差平方和
double sumSquareLoss(double[] data, double c1, int[] r1IndexArray, double c2, int[] r2IndexArray) {
double sum = 0.0d;
for (int idx : r1IndexArray) {
if (idx > -1) {
sum = sum + squareLoss(data[idx], c1);
}
}
for (int idx : r2IndexArray) {
if (idx > -1) {
sum = sum + squareLoss(data[idx], c2);
}
}
return sum;
}
double meanMatrix1DByIdx(double[] data, int[] indexArray) {
double ret = 0.0d;
int count = 0;
for (int idx : indexArray) {
if (idx > -1) {
ret = ret + data[idx];
count++;
}
}
ret = ret / count;
return ret;
}
// 差平方
double squareLoss(double y, double predict){
double loss = y - predict;
return loss * loss;
}
创建回归树
/**
* 生成回归树
* 递归方式
*
*/
void generateRegressTree(double[] xdata, double[] ydata, int currentLayer, int maxLayer){
if (currentLayer > maxLayer) {
return;
}
BinaryTreeNode btn = generateRegressTreeNode(xdata, ydata);
// 成员变量,记录所有的节点
binaryTreeNodeList.add(btn);
currentLayer++;
int divideLimit = 2; // 至少保证2个及以上的数据才能做分叉,这个决定节点的最小粒度
int leftCount = getValidIndexCount(btn.getLeftIdx());
if (leftCount >= divideLimit) {
double[] leftXData = getDataByIndex(xdata, leftCount, btn.getLeftIdx());
double[] leftYData = getDataByIndex(ydata, leftCount, btn.getLeftIdx());
generateRegressTree(leftXData, leftYData, currentLayer, maxLayer);
}
int rightCount = getValidIndexCount(btn.getRightIdx());
if (rightCount >= divideLimit) {
double[] rightXData = getDataByIndex(xdata, rightCount, btn.getRightIdx());
double[] rightYData = getDataByIndex(ydata, rightCount, btn.getRightIdx());
generateRegressTree(rightXData, rightYData, currentLayer, maxLayer);
}
}
double[] getDataByIndex(double[] data, int resultLength, int[] index) {
double[] ret = new double[resultLength];
int i = 0;
for (int idx : index) {
if (idx < 0) {
continue;
}
ret[i] = data[idx];
i++;
}
return ret;
}
int getValidIndexCount(int[] index) {
int ret = 0;
for (int idx : index) {
if (idx> -1) {
ret++;
}
}
return ret;
}
预测和训练逻辑,执行顺序是先训练再做预测
List<BinaryTreeNode> binaryTreeNodeList; // 所有节点的列表
double[] resultData; // 训练的结果,所有二叉树的
double[] splitPointArray;
public RegressionDecisionTree(){
binaryTreeNodeList = new ArrayList<BinaryTreeNode>();
}
/**
* 在训练的结果之上,依据输入的数据得出预测值
* @param newXData
* @return
*/
double[] predict(double[] newXData) {
double[] ret = new double[newXData.length];
for (int i = 0; i < newXData.length; i++) {
boolean valid = false;
// 寻找小于等于切分点和赋值对应的值
for (int j = 0; j < splitPointArray.length; j++) {
if (newXData[i] <= splitPointArray[j]) {
ret[i] = resultData[j];
valid = true;
break;
}
}
// 没找到的,给最大的。
if (valid == false) {
ret[i] = resultData[resultData.length - 1];
}
}
return ret;
}
/**
* 基于回归树,根据深度和给定的训练数据,得到多级的二叉树
* 详见参考的 算法5.5的 第(2)
* @param xdata
* @param ydata
* @param deep
*/
void train(double[] xdata, double[] ydata, int deep) {
generateRegressTree(xdata, ydata, 1, deep);
// calculate result average
splitPointArray = new double[binaryTreeNodeList.size()];
for (int j = 0; j < binaryTreeNodeList.size(); j++) {
BinaryTreeNode btn = binaryTreeNodeList.get(j);
splitPointArray[j] = btn.getSplitPoint();
}
// 将切分点值升序
Arrays.sort(splitPointArray);
// 每个切分点对一个值,并且依据从最小对齐原则,要给超过切分点最大值的位置,因此ResultData的length比切分点多一个,也就是最大的切分点会对应2个值
resultData = new double[splitPointArray.length + 1];
Map<Double, Double> pointAndTotal = new HashMap<Double, Double>();
Map<Double, Integer> pointAndCount = new HashMap<Double, Integer>();
int ext = 1;
for (int i = 0; i < xdata.length; i++) {
// 记录X小于 最小切分点值 的对应Y的和,以及数量,Key为最小切分点值
if (xdata[i] < splitPointArray[0]) {
assembleMapDoubleValue(pointAndTotal, splitPointArray[0], ydata[i]);
assembleMapIntegerValue(pointAndCount, splitPointArray[0]);
continue;
}
// 记录X大于等于 最大切分点值 的对应Y的和,以及数量,
// Key为最大切分点值+1 ,因为此处是额外多出来的。
if (xdata[i] >= splitPointArray[splitPointArray.length - 1]) {
assembleMapDoubleValue(pointAndTotal, splitPointArray[splitPointArray.length - 1] + ext, ydata[i]);
assembleMapIntegerValue(pointAndCount, splitPointArray[splitPointArray.length - 1] + ext);
continue;
}
// 记录X大于等于最小切分点值和X小于最大切分点值的对应Y的和,Key为切分点值
for (int j = 0; j < splitPointArray.length - 1; j++) {
if (xdata[i] >= splitPointArray[j] && xdata[i] < splitPointArray[j + 1]) {
assembleMapDoubleValue(pointAndTotal, splitPointArray[j + 1], ydata[i]);
assembleMapIntegerValue(pointAndCount, splitPointArray[j + 1]);
}
}
}
// 计算各个切分点值对应Y值的均值。
for (int i = 0; i < splitPointArray.length; i++) {
resultData[i] = pointAndTotal.get(splitPointArray[i]) / pointAndCount.get(splitPointArray[i]);
}
// 计算超过最大 切分点值的Y的均值。
resultData[resultData.length - 1] = pointAndTotal.get(splitPointArray[splitPointArray.length - 1] + ext)
/ pointAndCount.get(splitPointArray[splitPointArray.length - 1] + ext);
}
void assembleMapIntegerValue(Map<Double, Integer> map, double key) {
int result = 1;
if (map.containsKey(key)) {
result = map.get(key).intValue() + 1;
}
map.put(key, result);
}
void assembleMapDoubleValue(Map<Double, Double> map, double key, double value) {
double result = value;
if (map.containsKey(key)) {
result = map.get(key).doubleValue() + result;
}
map.put(key, result);
}
最后该验证了
double[] xdata = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
double[] ydata = {5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05};
RegressionDecisionTree dt = new RegressionDecisionTree ();
dt.train(xdata, ydata, 3);
double[] nxdata = {8.6, 5.5 , 4.4, 3.4, 2.4};
double[] predictRT = dt.predict(nxdata);
System.out.println("nxdata:" + Arrays.toString(nxdata));
System.out.println("预测结果predictRT :" + Arrays.toString(predictRT ));
------------------------------------------------------------
最终结果:
nxdata: 8.6 5.5 4.4 3.4 2.4
预测结果predictRT: 8.7 6.6 6.6 5.91 5.7
总结,最小二乘法回归树,将输入数据(xdata)分成多层树,也就是将输入数据(xdata)分成多个连续的区间,然后,再把输出数据(ydata),根据输入数据(xdata)区间对应的索引进行分组并计算均值。这样每个输入数据(xdata)区间都有对应的 ydata均值, 预测时候,判断新的输入数据落入那个区间,就返回对应ydata均值。