决策树算法介绍:原理与案例实现

在 Java 中,决策树算法是一种常用的机器学习算法,用于分类和预测任务。

原理
决策树通过对数据进行一系列基于特征的划分,构建一个类似于树的结构。每个内部节点表示对一个特征的测试,分支代表测试的结果,叶节点则对应最终的分类类别或预测值。

选择划分特征的依据通常是信息增益、基尼系数等指标,目的是使划分后的子节点纯度更高,即尽可能只包含同一类别的样本。

以下是一个使用 Java 实现简单决策树算法的案例:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class DecisionTreeNode {
    int featureIndex;
    double threshold;
    DecisionTreeNode left;
    DecisionTreeNode right;
    int classLabel;

    public DecisionTreeNode(int featureIndex, double threshold, DecisionTreeNode left, DecisionTreeNode right, int classLabel) {
        this.featureIndex = featureIndex;
        this.threshold = threshold;
        this.left = left;
        this.right = right;
        this.classLabel = classLabel;
    }
}

class DecisionTree {
    public DecisionTreeNode buildTree(List<List<Double>> data, List<Integer> labels) {
        // 计算信息增益,选择最佳特征和阈值进行划分
        int bestFeature = findBestFeature(data, labels);
        double bestThreshold = findBestThreshold(data, bestFeature, labels);

        List<List<Double>> leftData = new ArrayList<>();
        List<Integer> leftLabels = new ArrayList<>();
        List<List<Double>> rightData = new ArrayList<>();
        List<Integer> rightLabels = new ArrayList<>();

        for (int i = 0; i < data.size(); i++) {
            if (data.get(i).get(bestFeature) <= bestThreshold) {
                leftData.add(data.get(i));
                leftLabels.add(labels.get(i));
            } else {
                rightData.add(data.get(i));
                rightLabels.add(labels.get(i));
            }
        }

        if (leftData.isEmpty() || rightData.isEmpty()) {
            // 如果某个子节点数据为空,直接返回多数类标签
            Map<Integer, Integer> labelCount = new HashMap<>();
            for (Integer label : labels) {
                if (labelCount.containsKey(label)) {
                    labelCount.put(label, labelCount.get(label) + 1);
                } else {
                    labelCount.put(label, 1);
                }
            }

            int maxCount = 0;
            int majorityLabel = -1;
            for (Map.Entry<Integer, Integer> entry : labelCount.entrySet()) {
                if (entry.getValue() > maxCount) {
                    maxCount = entry.getValue();
                    majorityLabel = entry.getKey();
                }
            }

            return new DecisionTreeNode(-1, -1, null, null, majorityLabel);
        }

        DecisionTreeNode leftNode = buildTree(leftData, leftLabels);
        DecisionTreeNode rightNode = buildTree(rightData, rightLabels);

        return new DecisionTreeNode(bestFeature, bestThreshold, leftNode, rightNode, -1);
    }

    private int findBestFeature(List<List<Double>> data, List<Integer> labels) {
        // 计算每个特征的信息增益,选择信息增益最大的特征
        int numFeatures = data.get(0).size();
        double maxInfoGain = Double.MIN_VALUE;
        int bestFeature = -1;

        for (int feature = 0; feature < numFeatures; feature++) {
            double infoGain = calculateInfoGain(data, feature, labels);
            if (infoGain > maxInfoGain) {
                maxInfoGain = infoGain;
                bestFeature = feature;
            }
        }

        return bestFeature;
    }

    private double calculateInfoGain(List<List<Double>> data, int feature, List<Integer> labels) {
        // 计算信息增益的具体实现
        //...
    }

    private double findBestThreshold(List<List<Double>> data, int feature, List<Integer> labels) {
        // 找到最佳划分阈值的具体实现
        //...
    }

    public int predict(DecisionTreeNode root, List<Double> sample) {
        if (root.left == null && root.right == null) {
            return root.classLabel;
        }

        if (sample.get(root.featureIndex) <= root.threshold) {
            return predict(root.left, sample);
        } else {
            return predict(root.right, sample);
        }
    }

    public static void main(String[] args) {
        // 示例数据
        List<List<Double>> data = new ArrayList<>();
        List<Integer> labels = new ArrayList<>();

        // 构建决策树
        DecisionTree dt = new DecisionTree();
        DecisionTreeNode root = dt.buildTree(data, labels);

        // 进行预测
        List<Double> newSample = new ArrayList<>();
        int predictedLabel = dt.predict(root, newSample);
        System.out.println("预测标签: " + predictedLabel);
    }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

令人着迷

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值