CART（Classification and Regression Trees，分类与回归树）是一种决策树算法，可用于分类和回归问题。下面是CART算法用于分类问题的Java实现：每次划分时选择使基尼不纯度（Gini impurity）下降最多的特征和阈值，叶子节点输出其样本中出现次数最多的类别。

首先，我们需要定义一个节点类（`Node.java`），用于表示CART决策树中的每个节点：

```java
/**
 * A node of a CART decision tree.
 *
 * <p>An internal node sends a sample to {@link #getLeft()} when
 * {@code x[featureIndex] <= threshold} and to {@link #getRight()} otherwise. A node without
 * children is a leaf and carries the prediction in {@link #getValue()}.
 */
public class Node {
    private int featureIndex; // index of the feature tested at this node
    private double threshold; // split threshold for that feature
    private double value;     // prediction returned when this node is a leaf
    private Node left;        // subtree for samples with x[featureIndex] <= threshold
    private Node right;       // subtree for samples with x[featureIndex] > threshold

    /**
     * Creates a node.
     *
     * @param featureIndex index of the feature tested at this node
     * @param threshold    split threshold for {@code featureIndex}
     * @param value        prediction returned when this node is a leaf
     * @param left         left child, or {@code null} for a leaf
     * @param right        right child, or {@code null} for a leaf
     */
    public Node(int featureIndex, double threshold, double value, Node left, Node right) {
        this.featureIndex = featureIndex;
        this.threshold = threshold;
        this.value = value;
        this.left = left;
        this.right = right;
    }

    /** Returns {@code true} when this node has neither a left nor a right child. */
    public boolean isLeaf() {
        return left == null && right == null;
    }

    // Getters and setters

    public int getFeatureIndex() {
        return featureIndex;
    }

    public void setFeatureIndex(int featureIndex) {
        this.featureIndex = featureIndex;
    }

    public double getThreshold() {
        return threshold;
    }

    public void setThreshold(double threshold) {
        this.threshold = threshold;
    }

    public double getValue() {
        return value;
    }

    public void setValue(double value) {
        this.value = value;
    }

    public Node getLeft() {
        return left;
    }

    public void setLeft(Node left) {
        this.left = left;
    }

    public Node getRight() {
        return right;
    }

    public void setRight(Node right) {
        this.right = right;
    }
}
```

然后，我们需要定义一个CART类（`CART.java`），用于训练和预测：

```java
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;

/**
 * A CART classification tree.
 *
 * <p>Splits are chosen greedily by maximising the decrease in Gini impurity. Each leaf predicts
 * the most frequent label among the training samples that reach it. Class labels are passed as
 * {@code double} values.
 */
public class CART {
    /**
     * Minimum impurity decrease required to accept a split. It absorbs floating-point noise so
     * that a split which does not really lower the impurity is not taken.
     */
    private static final double MIN_GAIN = 1e-12;

    private Node root;           // root of the fitted tree; null until fit() is called
    private int maxDepth;        // maximum depth; a node at this depth becomes a leaf
    private int minSamplesSplit; // nodes with fewer samples than this become leaves

    /**
     * Creates an untrained tree.
     *
     * @param maxDepth        maximum tree depth (the root is at depth 0)
     * @param minSamplesSplit minimum number of samples a node needs to be split
     */
    public CART(int maxDepth, int minSamplesSplit) {
        this.maxDepth = maxDepth;
        this.minSamplesSplit = minSamplesSplit;
    }

    /**
     * Trains the tree on the given samples, replacing any previously fitted tree.
     *
     * @param X training samples, one row per sample; all rows must have the same length
     * @param y class label of each sample
     * @throws NullPointerException     if {@code X} or {@code y} is null
     * @throws IllegalArgumentException if {@code X} is empty or its length differs from {@code y}'s
     */
    public void fit(double[][] X, double[] y) {
        Objects.requireNonNull(X, "X");
        Objects.requireNonNull(y, "y");
        if (X.length == 0) {
            throw new IllegalArgumentException("X must contain at least one sample");
        }
        if (X.length != y.length) {
            throw new IllegalArgumentException(
                    "X and y must have the same length: " + X.length + " != " + y.length);
        }
        root = buildTree(X, y, 0);
    }

    /**
     * Predicts the label of one sample.
     *
     * @param x feature vector laid out like the training rows
     * @return the label stored in the leaf that {@code x} reaches
     * @throws IllegalStateException if {@link #fit} has not been called
     */
    public double predict(double[] x) {
        if (root == null) {
            throw new IllegalStateException("The tree is not fitted; call fit() first");
        }
        Node node = root;
        while (!node.isLeaf()) {
            node = x[node.getFeatureIndex()] <= node.getThreshold() ? node.getLeft() : node.getRight();
        }
        return node.getValue();
    }

    /** Recursively builds the subtree for a non-empty set of samples at the given depth. */
    private Node buildTree(double[][] X, double[] y, int depth) {
        int nSamples = X.length;
        double nodeImpurity = impurity(y);

        // Stop when there are too few samples, the depth limit is reached, or the node is pure.
        if (nSamples < minSamplesSplit || depth >= maxDepth || nodeImpurity == 0.0) {
            return leaf(y);
        }

        int nFeatures = X[0].length;
        int bestFeatureIndex = -1;
        double bestThreshold = 0.0;
        double bestGain = MIN_GAIN; // only splits that actually lower the impurity are accepted

        // Try every feature and every midpoint between consecutive distinct values.
        for (int feature = 0; feature < nFeatures; feature++) {
            double[] values = new double[nSamples];
            for (int i = 0; i < nSamples; i++) {
                values[i] = X[i][feature];
            }
            Arrays.sort(values);

            for (int j = 0; j < nSamples - 1; j++) {
                if (values[j] == values[j + 1]) {
                    continue; // equal neighbours do not define a new split point
                }
                double threshold = (values[j] + values[j + 1]) / 2;
                Partition candidate = split(X, y, feature, threshold);
                if (candidate.leftY.length == 0 || candidate.rightY.length == 0) {
                    continue; // degenerate split, e.g. caused by floating-point rounding
                }
                double childImpurity =
                        (candidate.leftY.length * impurity(candidate.leftY)
                                + candidate.rightY.length * impurity(candidate.rightY))
                                / nSamples;
                double gain = nodeImpurity - childImpurity;
                if (gain > bestGain) {
                    bestGain = gain;
                    bestFeatureIndex = feature;
                    bestThreshold = threshold;
                }
            }
        }

        // If no split lowers the impurity, return a leaf.
        if (bestFeatureIndex < 0) {
            return leaf(y);
        }

        Partition best = split(X, y, bestFeatureIndex, bestThreshold);
        Node left = buildTree(best.leftX, best.leftY, depth + 1);
        Node right = buildTree(best.rightX, best.rightY, depth + 1);
        return new Node(bestFeatureIndex, bestThreshold, -1, left, right);
    }

    /** Samples of one node divided by a single feature/threshold test. */
    private static final class Partition {
        final double[][] leftX;
        final double[] leftY;
        final double[][] rightX;
        final double[] rightY;

        Partition(double[][] leftX, double[] leftY, double[][] rightX, double[] rightY) {
            this.leftX = leftX;
            this.leftY = leftY;
            this.rightX = rightX;
            this.rightY = rightY;
        }
    }

    /**
     * Divides the samples by the test {@code X[i][featureIndex] <= threshold}. Samples that pass
     * go left and the rest go right, keeping their original order. Rows are shared with {@code X},
     * not copied.
     */
    private static Partition split(double[][] X, double[] y, int featureIndex, double threshold) {
        int leftCount = 0;
        for (double[] row : X) {
            if (row[featureIndex] <= threshold) {
                leftCount++;
            }
        }
        int rightCount = X.length - leftCount;

        double[][] leftX = new double[leftCount][];
        double[] leftY = new double[leftCount];
        double[][] rightX = new double[rightCount][];
        double[] rightY = new double[rightCount];
        int l = 0;
        int r = 0;
        for (int i = 0; i < X.length; i++) {
            if (X[i][featureIndex] <= threshold) {
                leftX[l] = X[i];
                leftY[l++] = y[i];
            } else {
                rightX[r] = X[i];
                rightY[r++] = y[i];
            }
        }
        return new Partition(leftX, leftY, rightX, rightY);
    }

    /** Creates a leaf that predicts the majority label of {@code y}. */
    private static Node leaf(double[] y) {
        return new Node(-1, -1, majorityLabel(y), null, null);
    }

    /**
     * Gini impurity {@code 1 - sum_k p_k^2} over the distinct labels in {@code y}, where
     * {@code p_k} is the fraction of samples with label {@code k}. Returns 0 for an empty array.
     */
    private static double impurity(double[] y) {
        if (y.length == 0) {
            return 0.0;
        }
        double gini = 1.0;
        for (int count : countLabels(y).values()) {
            double p = (double) count / y.length;
            gini -= p * p;
        }
        return gini;
    }

    /** Returns the most frequent label in a non-empty {@code y}; ties go to the smallest label. */
    private static double majorityLabel(double[] y) {
        double best = Double.NaN;
        int bestCount = 0;
        // TreeMap iterates labels in ascending order, so the strict '>' keeps the smallest on ties.
        for (Map.Entry<Double, Integer> entry : countLabels(y).entrySet()) {
            if (entry.getValue() > bestCount) {
                best = entry.getKey();
                bestCount = entry.getValue();
            }
        }
        return best;
    }

    /** Counts how many times each label occurs, ordered by label. */
    private static Map<Double, Integer> countLabels(double[] y) {
        Map<Double, Integer> counts = new TreeMap<>();
        for (double label : y) {
            counts.merge(label, 1, Integer::sum);
        }
        return counts;
    }

    // Getters and setters

    public Node getRoot() {
        return root;
    }

    public void setRoot(Node root) {
        this.root = root;
    }

    public int getMaxDepth() {
        return maxDepth;
    }

    public void setMaxDepth(int maxDepth) {
        this.maxDepth = maxDepth;
    }

    public int getMinSamplesSplit() {
        return minSamplesSplit;
    }

    public void setMinSamplesSplit(int minSamplesSplit) {
        this.minSamplesSplit = minSamplesSplit;
    }
}
```

最后，我们可以使用以下代码（`Main.java`）进行训练和预测：

```java
public class Main {
    public static void main(String[] args) {
        double[][] X = {{2.0, 4.0}, {3.0, 6.0}, {4.0, 8.0}, {5.0, 10.0}, {6.0, 12.0}};
        double[] y = {1, 1, 1, 2, 2};

        CART cart = new CART(2, 2);
        cart.fit(X, y);

        double[] x1 = {3.5, 7.0};
        double[] x2 = {5.5, 11.0};
        double y1 = cart.predict(x1);
        double y2 = cart.predict(x2);
        System.out.println(y1); // prints 1.0
        System.out.println(y2); // prints 2.0
    }
}
```

上述代码使用CART算法训练了一个决策树，并使用该决策树预测了两个样本的类别。Java要求每个公共类放在与类名同名的源文件中，因此以上三个类需要分别保存为三个文件。若要将其用于回归问题，可把基尼不纯度替换为方差（均方误差），并把叶子节点的输出改为样本的均值。


