Java程序员学算法(1) - 最小二乘回归树(Least Squares Regression Tree)

       在算法如此时髦的当前,作为一名工程方面的Java程序员,同时也是算法小白,有着追赶时髦的好奇心,尝试着向算法的大门张望一下,因此,在此写一下在门缝中看到的东西。先从能看懂的CART的最小二乘回归树开始,并使用擅长的Java来实现。

       首先,进入名词解释,CART( Classification And Regression Tree)和最小二乘法,它们的具体解释还请自己搜索吧。

最小二乘回归树:简单来说就是基于最小二乘的回归数。详细解释可参考 https://www.jianshu.com/p/005a4e6ac775 的第一部分。

       然后进入代码阶段:

       1、创建基础的对象,二叉树的节点- BinaryTreeNode ,其有五个属性

// 拆分点
double splitPoint;
// 拆分点左边的数据的均值,包含拆分点
double leftValue; // include the splitValue
// 拆分点左边元素的索引
int[] leftIdx;
// 拆分点左边的数据的均值
double rightValue;
// 拆分点右边元素的索引
int[] rightIdx;

        2、开始实现逻辑

              创建回归数的节点,如下为参考url的:算法5.5里面的公式5.21

/**
 * 依据给定的X和Y数据,基于最小二乘回归树生成 1 个二叉树(1个节点)
 * 选择最优的切分点
 * @param xdata
 * @param ydata
 * @return
 */
BinaryTreeNode generateRegressTreeNode(double[] xdata, double[] ydata) {
    BinaryTreeNode brn = null;
    int dataLength = xdata.length;
    double minSum = 0;
    // 遍历输入值,将xdata分为2个部分
    for (int i = 0; i < dataLength; i++) {
        // X数据的每一个值都可作为切分点
        double splitPoint = xdata[i];
            
        int[] r1Idx = new int[dataLength];
        int[] r2Idx = new int[dataLength];
            
        for (int j = 0; j < dataLength; j++) {
            r1Idx[j] = -1;
            r2Idx[j] = -1;
            if (xdata[j] > splitPoint) {
                r2Idx[j] = j;
            } else {
                r1Idx[j] = j;
            }
        }
        // 切分点左侧Y的数据均值
        double c1 = meanMatrix1DByIdx(ydata, r1Idx);
        // 切分点右侧Y的数据均值
        double c2 = meanMatrix1DByIdx(ydata, r2Idx);
            
        // 左侧和右侧值的差平方和
        double sumsl = sumSquareLoss(ydata, c1, r1Idx, c2, r2Idx);
            
        // 找最小的和(冒泡方式)
        if (i == 0 ||  minSum > sumsl) {
            minSum = sumsl;
            brn = new BinaryTreeNode(splitPoint, c1, c2);
            brn.setLeftIdx(r1Idx);
            brn.setRightIdx(r2Idx);
        }
    }
    return brn;
}
// 左侧和右侧值的差平方和
double sumSquareLoss(double[] data, double c1, int[] r1IndexArray, double c2, int[] r2IndexArray) {
    double sum = 0.0d;
    for (int idx : r1IndexArray) {
        if (idx > -1) {
            sum = sum + squareLoss(data[idx], c1);
        }
    }
        
    for (int idx : r2IndexArray) {
        if (idx > -1) {
            sum = sum + squareLoss(data[idx], c2);
        }
    }
        
    return sum;
}

double meanMatrix1DByIdx(double[] data, int[] indexArray) {
    double ret = 0.0d;
    int count = 0;
    for (int idx : indexArray) {
        if (idx > -1) {
            ret = ret + data[idx];
            count++;
        }
    }
    ret = ret / count;
    return ret;
}

// 差平方
double squareLoss(double y, double predict){
        
    double loss = y - predict;
    return loss * loss;
}

        创建回归树

/**
 * 生成回归树
 * 递归方式
 *
 */
void generateRegressTree(double[] xdata, double[] ydata, int currentLayer, int maxLayer){
        
    if (currentLayer > maxLayer) {
        return;
    }
    BinaryTreeNode btn = generateRegressTreeNode(xdata, ydata);
    // 成员变量,记录所有的节点
    binaryTreeNodeList.add(btn);
        
    currentLayer++;

    int divideLimit = 2; // 至少保证2个及以上的数据才能做分叉,这个决定节点的最小粒度
    int leftCount = getValidIndexCount(btn.getLeftIdx());
    if (leftCount >= divideLimit) {
        double[] leftXData = getDataByIndex(xdata, leftCount, btn.getLeftIdx());
        double[] leftYData = getDataByIndex(ydata, leftCount, btn.getLeftIdx());
        generateRegressTree(leftXData, leftYData, currentLayer, maxLayer);
    }
        
    int rightCount = getValidIndexCount(btn.getRightIdx());
    if (rightCount >= divideLimit) {
        double[] rightXData = getDataByIndex(xdata, rightCount, btn.getRightIdx()); 
        double[] rightYData = getDataByIndex(ydata, rightCount, btn.getRightIdx());
        generateRegressTree(rightXData, rightYData, currentLayer, maxLayer);
    }
}
    
double[] getDataByIndex(double[] data, int resultLength, int[] index) {
    double[] ret = new double[resultLength];
    int i = 0;
    for (int idx : index) {
        if (idx < 0) {
            continue;
        }
        ret[i] = data[idx];
        i++;
    }
    return ret;
}
int getValidIndexCount(int[] index) {
    int ret = 0;
    for (int idx : index) {
        if (idx> -1) {
            ret++;
        }
    }
    return ret;
}

   预测和训练逻辑,执行顺序是先训练再做预测

List<BinaryTreeNode> binaryTreeNodeList; // 所有节点的列表
double[] resultData; // 训练的结果,所有二叉树的
double[] splitPointArray;
    
public RegressionDecisionTree(){
    binaryTreeNodeList = new ArrayList<BinaryTreeNode>();
}
    
   
/**
 * 在训练的结果之上,依据输入的数据得出预测值
 * @param newXData
 * @return
 */
double[] predict(double[] newXData) {
    double[] ret = new double[newXData.length];
       
    for (int i = 0; i < newXData.length; i++) {
        boolean valid = false;
        // 寻找小于等于切分点和赋值对应的值
        for (int j = 0; j < splitPointArray.length; j++) {
            if (newXData[i] <= splitPointArray[j]) {
                ret[i] = resultData[j];
                valid = true;
                break;
            }
        }
        // 没找到的,给最大的。
        if (valid == false) {
            ret[i] = resultData[resultData.length - 1];
        }
    }
    return ret;
}
    
/**
 * 基于回归树,根据深度和给定的训练数据,得到多级的二叉树
 * 详见参考的 算法5.5的 第(2)
 * @param xdata
 * @param ydata
 * @param deep
 */
void train(double[] xdata, double[] ydata, int deep) {
        
    generateRegressTree(xdata, ydata, 1, deep);
        
    // calculate result average
    splitPointArray = new double[binaryTreeNodeList.size()];
    for (int j = 0; j < binaryTreeNodeList.size(); j++) {
        BinaryTreeNode btn = binaryTreeNodeList.get(j);
        splitPointArray[j] = btn.getSplitPoint();
    }
    // 将切分点值升序
    Arrays.sort(splitPointArray);
    // 每个切分点对一个值,并且依据从最小对齐原则,要给超过切分点最大值的位置,因此ResultData的length比切分点多一个,也就是最大的切分点会对应2个值
    resultData = new double[splitPointArray.length + 1];
    Map<Double, Double> pointAndTotal = new HashMap<Double, Double>();
    Map<Double, Integer> pointAndCount = new HashMap<Double, Integer>();

    int ext = 1;
    for (int i = 0; i < xdata.length; i++) {
        // 记录X小于 最小切分点值 的对应Y的和,以及数量,Key为最小切分点值
        if (xdata[i] < splitPointArray[0]) {
                
            assembleMapDoubleValue(pointAndTotal, splitPointArray[0], ydata[i]);
            assembleMapIntegerValue(pointAndCount, splitPointArray[0]);
            continue;
        } 
        // 记录X大于等于 最大切分点值 的对应Y的和,以及数量,
        // Key为最大切分点值+1 ,因为此处是额外多出来的。
        if (xdata[i] >= splitPointArray[splitPointArray.length - 1]) {               
            assembleMapDoubleValue(pointAndTotal, splitPointArray[splitPointArray.length - 1] + ext, ydata[i]);
            assembleMapIntegerValue(pointAndCount, splitPointArray[splitPointArray.length - 1] + ext);
            continue;
        }
        // 记录X大于等于最小切分点值和X小于最大切分点值的对应Y的和,Key为切分点值
        for (int j = 0; j < splitPointArray.length - 1; j++) {
               
            if (xdata[i] >= splitPointArray[j] && xdata[i] < splitPointArray[j + 1]) {
                assembleMapDoubleValue(pointAndTotal, splitPointArray[j + 1], ydata[i]);
                assembleMapIntegerValue(pointAndCount, splitPointArray[j + 1]);
            }
        }
    }
    
    // 计算各个切分点值对应Y值的均值。
    for (int i = 0; i < splitPointArray.length; i++) {
         resultData[i] = pointAndTotal.get(splitPointArray[i]) / pointAndCount.get(splitPointArray[i]);
    }
    // 计算超过最大 切分点值的Y的均值。
    resultData[resultData.length - 1] = pointAndTotal.get(splitPointArray[splitPointArray.length - 1] + ext) 
                / pointAndCount.get(splitPointArray[splitPointArray.length - 1] + ext);
}
    
void assembleMapIntegerValue(Map<Double, Integer> map, double key) {
    int result = 1;
    if (map.containsKey(key)) {
        result = map.get(key).intValue() + 1;
    }
    map.put(key, result);
}
    
void assembleMapDoubleValue(Map<Double, Double> map, double key, double value) {
    double result = value;
    if (map.containsKey(key)) {
        result = map.get(key).doubleValue() + result;
    }
    map.put(key, result);
}

最后该验证了

double[] xdata = {1,    2,    3,    4,    5,    6,    7,    8,    9,    10};
double[] ydata = {5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05};
RegressionDecisionTree dt = new RegressionDecisionTree ();
        
dt.train(xdata, ydata, 3);

double[] nxdata = {8.6, 5.5 , 4.4, 3.4, 2.4};
double[] predictRT = dt.predict(nxdata);

System.out.println("nxdata:" + Arrays.toString(nxdata));
System.out.println("预测结果predictRT :" + Arrays.toString(predictRT ));

------------------------------------------------------------
最终结果:
nxdata: 8.6  5.5  4.4  3.4  2.4  

预测结果predictRT: 8.7  6.6  6.6  5.91  5.7  

总结,最小二乘法回归树,将输入数据(xdata)分成多层树,也就是将输入数据(xdata)分成多个连续的区间,然后,再把输出数据(ydata),根据输入数据(xdata)区间对应的索引进行分组并计算均值。这样每个输入数据(xdata)区间都有对应的 ydata均值, 预测时候,判断新的输入数据落入那个区间,就返回对应ydata均值。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
复数偏最小二乘回归算法(Partial Least Squares Regression,PLSR)是一种用于建立输入变量和输出变量之间线性关系的回归分析方法。与普通的最小二乘回归相比,PLSR可用于处理多变量共线性(multicollinearity)问题,即当输入变量之间存在高度相关性时。 PLSR的原理是将输入变量和输出变量分别投影到一个低维的空间中,使得在该空间中的投影值最大程度地保留原始数据的信息。具体来说,PLSR通过寻找一组正交的投影向量,将输入变量和输出变量分别投影到该向量空间中,从而得到一组新的变量。这些新变量代表原始变量的线性组合,被称为潜在变量(latent variables)。潜在变量的数量通常小于原始变量的数量,因此,通过PLSR可以实现对数据的降维处理。 PLSR的关键是选择合适的投影向量。PLSR采用交替最小二乘法(alternating least squares,ALS)来计算投影向量。该方法先选择一个初始的投影向量,然后对输入变量和输出变量进行投影,得到新的潜在变量。接着,将新的潜在变量作为输入变量,再次进行投影,得到更新后的投影向量。该过程迭代执行,直到收敛或达到预设的迭代次数。 PLSR适用于多元统计分析、数据挖掘、化分析、生物医工程等领域。它可以用于建立输入变量和输出变量之间的线性关系模型,同时对数据进行降维处理,提高模型的解释性和预测性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值