CART 决策树算法JAVA

1、一个枚举类,用来指示特征是数值型的还是字符型的

/**
 * Enum indicating whether a feature column holds numeric or string values.
 * The tree code branches on this to decide between a threshold split
 * (Number) and a one-child-per-value split (String).
 * @date 2023/10/20 15:26
 * @author luohao
 */
public enum DataType {
    // String-valued (categorical) feature
    String,
    // Numeric feature
    Number;
}

2、预测结果对象

/**
 * Holds the outcome of a single prediction: the predicted label together
 * with the probability assigned to every known label.
 * @date 2023/10/20 15:27
 * @author luohao
 */
public class PredictResult {
    String[] labelArr;
    String predictLabel;
    double[] predictArr;

    public PredictResult(String[] labelArr, String predictLabel, double[] predictArr) {
        this.labelArr = labelArr;
        this.predictLabel = predictLabel;
        this.predictArr = predictArr;
    }

    @Override
    public String toString() {
        // Format: PredictResult{predictLabel='X', predictArr=[ label:prob , ... ]}
        return String.format("PredictResult{predictLabel='%s', predictArr=%s}",
                predictLabel, CartDecisionTree.predictArrToString(predictArr, labelArr));
    }
}

3、训练数据集存放对象

import java.util.ArrayList;
import java.util.List;

/**
 * Container for the training data set: parallel lists of feature vectors
 * and labels, plus the declared type of each feature column.
 * @date 2023/10/20 15:13
 * @author luohao
 */
public class TrainDataSet {
    /**
     * Feature vectors, one Object[] per sample
     **/
    public List<Object[]> features = new ArrayList<>();
    /**
     * Data type (String or Number) of each feature column
     **/
    DataType[] dataTypes;
    /**
     * Labels, parallel to the features list
     **/
    public List<String> labels = new ArrayList<>();
    /**
     * Dimension every feature vector must have
     **/
    public int featureDim;

    public TrainDataSet(DataType[] dataTypes) {
        this.dataTypes = dataTypes;
        this.featureDim = dataTypes.length;
    }

    /** @return the number of samples stored */
    public int size() {
        return labels.size();
    }

    public Object[] getFeature(int index) {
        return features.get(index);
    }

    public String getLabel(int index) {
        return labels.get(index);
    }

    /**
     * Adds one sample (feature vector plus its label).
     * @throws IllegalArgumentException if the vector's length differs from featureDim
     */
    public void addData(Object[] feature, String label) {
        if (featureDim != feature.length) {
            throwDimensionMismatchException(feature.length);
        }
        features.add(feature);
        labels.add(label);
    }

    public void throwDimensionMismatchException(int errorLen) {
        // IllegalArgumentException is the idiomatic exception for a bad argument;
        // it subclasses RuntimeException, so existing callers that catch
        // RuntimeException keep working.
        throw new IllegalArgumentException("DimensionMismatchError: 你应该传入维度为 " + featureDim + " 的特征向量 , 但你传入了维度为 " + errorLen + " 的特征向量");
    }
}

4、CART 决策树算法对象

import java.util.*;

/**
 * CART decision-tree classifier. The tree is built breadth-first: each node
 * is split on the feature whose best split yields the smallest weighted Gini
 * index. Pre-pruning is controlled by {@code maxDeep}, {@code maxLeafNum}
 * and {@code minDataSize} (each may be {@code null}, meaning "no limit").
 * @date 2023/10/20 15:28
 * @author luohao
 */
public class CartDecisionTree {
    /**
     * Training data set
     **/
    TrainDataSet trainDataSet;
    /**
     * All distinct class labels
     **/
    String[] labelArr;
    /**
     * Pre-pruning: maximum tree depth
     **/
    Integer maxDeep;
    /**
     * Pre-pruning: maximum number of leaf nodes
     **/
    Integer maxLeafNum;
    /**
     * Pre-pruning: minimum number of samples per child node
     **/
    Integer minDataSize;
    /**
     * Root of the trained decision tree
     **/
    DecisionTree root;


    public CartDecisionTree(TrainDataSet trainDataSet, Integer maxDeep, Integer maxLeafNum, Integer minDataSize) {
        this.trainDataSet = trainDataSet;
        this.maxDeep = maxDeep;
        this.maxLeafNum = maxLeafNum;
        this.minDataSize = minDataSize;
        // Deduplicate the labels to obtain every distinct class
        this.labelArr = new HashSet<>(trainDataSet.labels).toArray(new String[0]);
    }

    // Initialize the root as a leaf holding the entire training set
    public void initModel() {
        root = createLeafNode(1, trainDataSet.features, trainDataSet.labels, new boolean[trainDataSet.featureDim]);
    }

    /**
     * Predicts the class of one feature vector by walking the tree from the root.
     * @param features feature vector, laid out like the training data
     * @return the predicted label and the per-class probability vector
     */
    public PredictResult predict(Object[] features) {
        DecisionTree tree = root;
        while (tree.condition != null) {
            int featureIndex = tree.condition.featureIndex;
            if (trainDataSet.dataTypes[featureIndex].equals(DataType.String)) {
                // Categorical split: follow the child whose branch value matches
                List<String> stringList = (List<String>) tree.condition.conditionValue;
                DecisionTree next = null;
                for (int i = 0; i < stringList.size(); i++) {
                    if (stringList.get(i).equals(features[featureIndex])) {
                        next = tree.children.get(i);
                        break;
                    }
                }
                if (next == null) {
                    // BUG FIX: a category value never seen during training has no
                    // matching child; the old code left `tree` unchanged and spun
                    // forever in this while-loop. Fall back to the current node's
                    // majority prediction instead.
                    break;
                }
                tree = next;
            } else {
                // Numeric split: left child is >= threshold, right child is <
                if ((double) features[featureIndex] >= (double) tree.condition.conditionValue) {
                    tree = tree.children.get(0);
                } else {
                    tree = tree.children.get(1);
                }
            }
        }
        return new PredictResult(labelArr, tree.predictLabel, tree.predictArr);
    }

    /**
     * Trains the tree: BFS over the nodes, splitting each one on the feature
     * whose best split has the lowest weighted Gini index, subject to the
     * pre-pruning limits.
     */
    public void fit() {
        initModel();
        // BFS construction of the decision tree
        Queue<DecisionTree> queue = new LinkedList<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            // Take the next incomplete node (it already has depth, data and a
            // majority label, but no split condition or children yet)
            DecisionTree decisionTree = queue.poll();
            // Scan every non-tabooed feature for the split with the smallest Gini index
            BranchResult bestBranchResult = null;
            // Pre-pruning: limit the tree depth
            if (maxDeep == null || decisionTree.deep + 1 <= maxDeep) {
                for (int featureIndex = 0; featureIndex < trainDataSet.featureDim; featureIndex++) {
                    // Only evaluate features that are not tabooed for this node
                    if (!decisionTree.featureTabuArr[featureIndex]) {
                        BranchResult branchResult;
                        if (trainDataSet.dataTypes[featureIndex].equals(DataType.String)) {
                            // Split on a categorical feature
                            branchResult = stringBranch(featureIndex, decisionTree);
                        } else {
                            // Split on a numeric feature
                            branchResult = numberBranch(featureIndex, decisionTree);
                        }
                        if (branchResult != null
                                && (bestBranchResult == null || branchResult.gini < bestBranchResult.gini)) {
                            bestBranchResult = branchResult;
                        }
                    }
                }
            }
            // Attach the children of the best split and enqueue them
            if (bestBranchResult != null) {
                decisionTree.children.addAll(bestBranchResult.decisionTreeList);
                decisionTree.condition = bestBranchResult.condition;
                for (DecisionTree child : decisionTree.children) {
                    child.featureTabuArr = decisionTree.featureTabuArr.clone();
                    // A categorical split is exhaustive, so the feature can be
                    // tabooed: no descendant ever needs to split on it again
                    if (bestBranchResult.condition.dataType.equals(DataType.String)) {
                        child.featureTabuArr[bestBranchResult.condition.featureIndex] = true;
                    }
                    child.deep = decisionTree.deep + 1;
                }
                queue.addAll(decisionTree.children);
            }
            // Pre-pruning: limit the number of leaf nodes.
            // NOTE(review): the queue size is used as a proxy for the leaf count,
            // which is only an approximation — confirm this is the intent.
            if (maxLeafNum != null && queue.size() >= maxLeafNum) {
                break;
            }
        }
        root.print();
    }

    // Split on a numeric feature: try the midpoint between every pair of
    // adjacent distinct values and keep the split with the lowest Gini index
    private BranchResult numberBranch(int featureIndex, DecisionTree decisionTree) {
        BranchResult bestBranchResult = null;
        // Tracks already-seen values as strings rounded to 6 decimal places, to
        // avoid floating-point hashing surprises; doubles as value deduplication
        HashSet<String> valueSet = new HashSet<>();
        // Collect all distinct values of this feature
        List<Double> valueList = new ArrayList<>();
        for (int i = 0; i < decisionTree.features.size(); i++) {
            double v = (double) decisionTree.features.get(i)[featureIndex];
            if (valueSet.add(String.format("%.6f", v))) {
                valueList.add(v);
            }
        }
        // Sort ascending so adjacent pairs are meaningful
        Collections.sort(valueList);
        // Try the midpoint of every adjacent pair as the split threshold
        valueSet = new HashSet<>();
        for (int i = 0; i < valueList.size() - 1; i++) {
            double mid = (valueList.get(i) + valueList.get(i + 1)) / 2.0;
            if (valueSet.add(String.format("%.6f", mid))) {
                // Per-side label counts for the Gini computation
                Map<String, Integer> leftMap = new HashMap<>();
                Map<String, Integer> rightMap = new HashMap<>();
                // Left child receives >= mid, right child receives < mid
                DecisionTree left = new DecisionTree();
                DecisionTree right = new DecisionTree();
                // Distribute the samples over the two children
                for (int j = 0; j < decisionTree.features.size(); j++) {
                    if ((double) decisionTree.features.get(j)[featureIndex] >= mid) {
                        left.features.add(decisionTree.features.get(j));
                        left.labels.add(decisionTree.labels.get(j));
                        leftMap.merge(decisionTree.labels.get(j), 1, Integer::sum);
                    } else {
                        right.features.add(decisionTree.features.get(j));
                        right.labels.add(decisionTree.labels.get(j));
                        rightMap.merge(decisionTree.labels.get(j), 1, Integer::sum);
                    }
                }
                // Pre-pruning: enforce the minimum samples per node
                if (minDataSize != null
                        && (left.labels.size() < minDataSize || right.labels.size() < minDataSize)) {
                    continue;
                }
                // Weighted Gini index of this split
                double leftGINI = (double) left.labels.size() / decisionTree.labels.size() * giniIndex(left.labels.size(), leftMap);
                double rightGINI = (double) right.labels.size() / decisionTree.labels.size() * giniIndex(right.labels.size(), rightMap);
                double gini = leftGINI + rightGINI;
                if (bestBranchResult == null || gini < bestBranchResult.gini) {
                    bestBranchResult = new BranchResult();
                    bestBranchResult.gini = gini;
                    // Left child is >= , right child is <
                    calcPredictLabelAndArr(left, leftMap);
                    calcPredictLabelAndArr(right, rightMap);
                    bestBranchResult.decisionTreeList.add(left);
                    bestBranchResult.decisionTreeList.add(right);
                    bestBranchResult.condition = new Condition(DataType.Number, featureIndex, mid);
                }
            }
        }
        return bestBranchResult;
    }

    // Split on a categorical feature: one child per distinct value
    private BranchResult stringBranch(int featureIndex, DecisionTree decisionTree) {
        // Group the samples by the value of the chosen feature
        Map<String, DecisionTree> decisionTreeListMap = new HashMap<>();
        // Per-value label counts, used for the Gini computation
        Map<String, Map<String, Integer>> giniCalcMap = new HashMap<>();
        for (int i = 0; i < decisionTree.features.size(); i++) {
            String key = (String) (decisionTree.features.get(i)[featureIndex]);
            DecisionTree bucket = decisionTreeListMap.computeIfAbsent(key, k -> new DecisionTree());
            bucket.features.add(decisionTree.features.get(i));
            bucket.labels.add(decisionTree.labels.get(i));
            giniCalcMap.computeIfAbsent(key, k -> new HashMap<>())
                    .merge(decisionTree.labels.get(i), 1, Integer::sum);
        }
        // A single group means the node is already pure in this feature, so
        // there is nothing to split on — report "no split" with null
        if (decisionTreeListMap.size() <= 1) {
            return null;
        }
        // Pre-pruning: enforce the minimum samples per node
        if (minDataSize != null) {
            for (String key : decisionTreeListMap.keySet()) {
                if (decisionTreeListMap.get(key).labels.size() < minDataSize) {
                    return null;
                }
            }
        }
        // Compute the weighted Gini index and build the BranchResult
        List<DecisionTree> decisionTreeList = new ArrayList<>();
        List<String> conditionValue = new ArrayList<>();
        double gini = 0d;
        for (String key : decisionTreeListMap.keySet()) {
            DecisionTree tree = decisionTreeListMap.get(key);
            calcPredictLabelAndArr(tree, giniCalcMap.get(key));
            decisionTreeList.add(tree);
            conditionValue.add(key);
            // Weight each group's Gini index by its share of the samples
            double rate = ((double) tree.labels.size() / decisionTree.labels.size());
            gini += (rate * giniIndex(tree.labels.size(), giniCalcMap.get(key)));
        }
        BranchResult branchResult = new BranchResult();
        branchResult.gini = gini;
        branchResult.decisionTreeList = decisionTreeList;
        branchResult.condition = new Condition(DataType.String, featureIndex, conditionValue);
        return branchResult;
    }

    // Gini index of one group: 1 - sum(p_i^2) over the label frequencies
    private double giniIndex(int totalCnt, Map<String, Integer> map) {
        double gini = 1d;
        for (Integer cnt : map.values()) {
            gini -= Math.pow((double) cnt / totalCnt, 2);
        }
        return gini;
    }

    // Build a leaf node holding the given samples
    private DecisionTree createLeafNode(int deep, List<Object[]> features, List<String> labels, boolean[] featureTabuArr) {
        DecisionTree leaf = new DecisionTree();
        leaf.features = features;
        leaf.labels = labels;
        leaf.deep = deep;
        leaf.featureTabuArr = featureTabuArr;
        calcPredictLabelAndArr(leaf, labels);
        return leaf;
    }

    // Derive the majority label and per-class probability vector from a label list
    private void calcPredictLabelAndArr(DecisionTree tree, List<String> labels) {
        Map<String, Integer> map = new HashMap<>();
        String mostLabel = null;
        int mostNum = -1;
        for (String label : labels) {
            int num = map.merge(label, 1, Integer::sum);
            if (num > mostNum) {
                mostNum = num;
                mostLabel = label;
            }
        }
        if (mostNum == -1) {
            throw new RuntimeException("没找到最多的标签");
        }
        // The counts sum to the number of labels, so use the list size directly
        int totalCnt = labels.size();
        double[] predictArr = new double[labelArr.length];
        for (int i = 0; i < labelArr.length; i++) {
            predictArr[i] = (double) map.getOrDefault(labelArr[i], 0) / totalCnt;
        }
        tree.predictLabel = mostLabel;
        tree.predictArr = predictArr;
    }

    // Derive the majority label and per-class probability vector from label counts
    private void calcPredictLabelAndArr(DecisionTree tree, Map<String, Integer> map) {
        String mostLabel = null;
        int mostNum = -1;
        int totalCnt = 0;
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (entry.getValue() > mostNum) {
                mostNum = entry.getValue();
                mostLabel = entry.getKey();
            }
            totalCnt += entry.getValue();
        }
        if (mostNum == -1) {
            throw new RuntimeException("没找到最多的标签");
        }
        double[] predictArr = new double[labelArr.length];
        for (int i = 0; i < labelArr.length; i++) {
            predictArr[i] = (double) map.getOrDefault(labelArr[i], 0) / totalCnt;
        }
        tree.predictLabel = mostLabel;
        tree.predictArr = predictArr;
    }

    /**
     * Formats a probability vector as "[ label:prob , ... ]".
     * @param predictArr per-class probabilities
     * @param labelArr class labels, parallel to predictArr
     * @throws RuntimeException if the two arrays differ in length
     */
    public static String predictArrToString(double[] predictArr, String[] labelArr) {
        if (predictArr.length != labelArr.length) {
            throw new RuntimeException("传入的概率矩阵维度和类型数组长度不一致: " + predictArr.length + " != " + labelArr.length);
        }
        StringBuilder str = new StringBuilder("[ ");
        for (int i = 0; i < labelArr.length - 1; i++) {
            str.append(labelArr[i]).append(":").append(String.format("%.2f", predictArr[i])).append(" , ");
        }
        str.append(labelArr[labelArr.length - 1]).append(":").append(String.format("%.2f", predictArr[predictArr.length - 1])).append(" ]");
        return str.toString();
    }

    // Result of trying a split on one feature (static: needs no enclosing instance)
    static class BranchResult {
        // Weighted Gini index of the split
        double gini;
        // Child nodes produced by the split
        List<DecisionTree> decisionTreeList = new ArrayList<>();
        // The split condition
        Condition condition;
    }

    // A decision-tree node (non-static: printSelf reads the outer labelArr)
    class DecisionTree {
        // Split condition; null while the node is still a leaf
        Condition condition;
        List<Object[]> features = new ArrayList<>();
        List<String> labels = new ArrayList<>();
        List<DecisionTree> children = new ArrayList<>();
        // Most likely label for samples reaching this node
        String predictLabel;
        // Probability of each class at this node
        double[] predictArr;
        // Depth of this node in the tree
        int deep;
        /**
         * Marks features that no longer need to be considered for splits
         **/
        boolean[] featureTabuArr;

        // Pre-order traversal printing every node
        public void print() {
            this.printSelf();
            for (DecisionTree child : children) {
                child.print();
            }
        }

        public void printSelf() {
            if (condition != null) {
                System.out.println("deep: " + deep + " , predictLabel: " + predictLabel + " , predictArr: " + predictArrToString(predictArr, labelArr) + " , featureIndex: " + condition.featureIndex + " , condition: " + condition.conditionValue);
            } else {
                System.out.print("deep: " + deep + " , predictLabel: " + predictLabel + " , predictArr: " + predictArrToString(predictArr, labelArr) + " , features: ");
                for (Object[] feature : features) {
                    System.out.print(Arrays.toString(feature) + " , ");
                }
                System.out.println("labels: " + labels);
            }
        }

    }

    // A split condition (static: needs no enclosing instance)
    static class Condition {
        DataType dataType;
        int featureIndex;
        /**
         * If dataType is String, conditionValue is a List<String> (one entry
         * per child); if dataType is Number, conditionValue is a double and
         * the left child is >= while the right child is <
         **/
        Object conditionValue;

        public Condition(DataType dataType, int featureIndex, Object conditionValue) {
            this.dataType = dataType;
            this.featureIndex = featureIndex;
            this.conditionValue = conditionValue;
        }
    }
}

5、测试运行算法的类

import java.util.Arrays;

/**
 * Driver class that trains the CART tree on the watermelon data set and
 * runs one prediction.
 * @date 2023/10/20 15:35
 * @author luohao
 */
public class TestRun {
    public static void main(String[] args) {
        // Run the purely-categorical classification demo
        testStringData();
    }

    public static void testStringData() {
        System.out.println("================================================================== 测试纯文本的分类 ==================================================================");
        // All six feature columns are string-valued:
        // 色泽,根蒂,敲声,纹理,脐部,触感 -> 好瓜
        DataType[] types = new DataType[6];
        Arrays.fill(types, DataType.String);
        TrainDataSet trainDataSet = new TrainDataSet(types);
        // Each row holds the six feature values followed by the label
        String[][] samples = {
                {"青绿","蜷缩","浊响","清晰","凹陷","硬滑","是"},
                {"乌黑","蜷缩","沉闷","清晰","凹陷","硬滑","是"},
                {"乌黑","蜷缩","浊响","清晰","凹陷","硬滑","是"},
                {"青绿","蜷缩","沉闷","清晰","凹陷","硬滑","是"},
                {"浅白","蜷缩","浊响","清晰","凹陷","硬滑","是"},
                {"青绿","稍蜷","浊响","清晰","稍凹","软粘","是"},
                {"乌黑","稍蜷","浊响","稍糊","稍凹","软粘","是"},
                {"乌黑","稍蜷","浊响","清晰","稍凹","硬滑","是"},
                {"乌黑","稍蜷","沉闷","稍糊","稍凹","硬滑","否"},
                {"青绿","硬挺","清脆","清晰","平坦","软粘","否"},
                {"浅白","硬挺","清脆","模糊","平坦","硬滑","否"},
                {"浅白","蜷缩","浊响","模糊","平坦","软粘","否"},
                {"青绿","稍蜷","浊响","稍糊","凹陷","硬滑","否"},
                {"浅白","稍蜷","沉闷","稍糊","凹陷","硬滑","否"},
                {"乌黑","稍蜷","浊响","清晰","稍凹","软粘","否"},
                {"浅白","蜷缩","浊响","模糊","平坦","硬滑","否"},
                {"青绿","蜷缩","沉闷","稍糊","稍凹","硬滑","否"},
        };
        for (String[] row : samples) {
            // First six cells are the feature vector, the last one is the label
            trainDataSet.addData(Arrays.copyOfRange(row, 0, 6, Object[].class), row[6]);
        }
        long startTime = System.currentTimeMillis();
        CartDecisionTree cartDecisionTree = new CartDecisionTree(trainDataSet, null, null, null);
        cartDecisionTree.fit();
        System.out.println("训练用时: " + (System.currentTimeMillis() - startTime) / 1000d + " s");
        System.out.println("用训练好的模型进行预测: ");
        Object[] testData = {"浅白","蜷缩","浊响","清晰","凹陷","硬滑"};
        System.out.println("TestData: " + Arrays.toString(testData) + " : " + cartDecisionTree.predict(testData));
        // Expected: TestData: [浅白, 蜷缩, 浊响, 清晰, 凹陷, 硬滑] : PredictResult{predictLabel='是', predictArr=[ 否:0.00 , 是:1.00 ]}
    }
}

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值