决策树算法介绍:原理与案例实现

一、决策树算法概述

决策树是一种常用的分类与回归方法,适用于处理非线性关系的数据。它通过一系列的分裂操作,将数据集划分成不同的子集,从而实现数据分类或回归。决策树的优点包括易于理解和解释,处理高维数据的能力强,并且可以处理缺失数据。本文将详细介绍决策树的原理,并通过案例代码实现决策树分类。

二、决策树的原理

1. 树的结构

决策树由节点和边组成:

  • 根节点:树的起始点,代表整个数据集。
  • 内部节点:每个内部节点表示对某个属性的测试,并根据测试结果将数据集划分为不同的子集。
  • 叶节点:表示分类结果或回归值。
2. 树的生成

决策树的生成过程可以通过递归地选择最佳分裂点来构建。这通常包括以下步骤:

  1. 选择最佳属性:选择能够最大化信息增益的属性进行分裂。
  2. 分裂数据集:根据选定的属性,将数据集分裂为不同的子集。
  3. 递归构建子树:对子集递归地应用上述步骤,直到满足停止条件(如所有实例属于同一类别,属性用尽,或达到最大深度)。
3. 信息增益

信息增益是选择最佳分裂属性的常用标准。信息增益基于熵(Entropy)的概念,熵用于衡量数据集的纯度。信息增益计算公式如下:

1. 数据准备

Iris数据集包含150个样本,每个样本有四个特征(萼片长度、萼片宽度、花瓣长度、花瓣宽度)和一个类别标签(Setosa、Versicolour、Virginica)。

2. 代码实现

示例代码:决策树分类

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class DecisionTree {
    // 节点类
    class Node {
        String attribute;
        Map<String, Node> children = new HashMap<>();
        String label;

        Node(String attribute) {
            this.attribute = attribute;
        }

        Node(String attribute, String label) {
            this.attribute = attribute;
            this.label = label;
        }
    }

    private Node root;

    // 计算熵
    private double entropy(List<String> labels) {
        Map<String, Integer> labelCounts = new HashMap<>();
        for (String label : labels) {
            labelCounts.put(label, labelCounts.getOrDefault(label, 0) + 1);
        }
        double entropy = 0.0;
        for (int count : labelCounts.values()) {
            double probability = (double) count / labels.size();
            entropy -= probability * Math.log(probability) / Math.log(2);
        }
        return entropy;
    }

    // 计算信息增益
    private double informationGain(List<List<String>> data, List<String> labels, String attribute) {
        int index = data.get(0).indexOf(attribute);
        double originalEntropy = entropy(labels);
        Map<String, List<String>> subsets = new HashMap<>();
        for (int i = 0; i < data.size(); i++) {
            String key = data.get(i).get(index);
            subsets.putIfAbsent(key, new ArrayList<>());
            subsets.get(key).add(labels.get(i));
        }
        double newEntropy = 0.0;
        for (List<String> subset : subsets.values()) {
            newEntropy += ((double) subset.size() / labels.size()) * entropy(subset);
        }
        return originalEntropy - newEntropy;
    }

    // 选择最佳分裂属性
    private String bestAttribute(List<List<String>> data, List<String> labels, List<String> attributes) {
        String bestAttribute = null;
        double bestGain = Double.NEGATIVE_INFINITY;
        for (String attribute : attributes) {
            double gain = informationGain(data, labels, attribute);
            if (gain > bestGain) {
                bestGain = gain;
                bestAttribute = attribute;
            }
        }
        return bestAttribute;
    }

    // 生成决策树
    public Node buildTree(List<List<String>> data, List<String> labels, List<String> attributes) {
        if (new HashSet<>(labels).size() == 1) {
            return new Node(null, labels.get(0)); // 只有一个类别,返回叶节点
        }
        if (attributes.isEmpty()) {
            return new Node(null, mostCommonLabel(labels)); // 属性用尽,返回多数类别
        }
        String bestAttribute = bestAttribute(data, labels, attributes);
        Node node = new Node(bestAttribute);
        Map<String, List<List<String>>> subsets = new HashMap<>();
        Map<String, List<String>> subLabels = new HashMap<>();
        for (int i = 0; i < data.size(); i++) {
            String key = data.get(i).get(data.get(0).indexOf(bestAttribute));
            subsets.putIfAbsent(key, new ArrayList<>());
            subLabels.putIfAbsent(key, new ArrayList<>());
            subsets.get(key).add(data.get(i));
            subLabels.get(key).add(labels.get(i));
        }
        for (String value : subsets.keySet()) {
            List<String> subAttributes = new ArrayList<>(attributes);
            subAttributes.remove(bestAttribute);
            node.children.put(value, buildTree(subsets.get(value), subLabels.get(value), subAttributes));
        }
        return node;
    }

    // 预测
    public String predict(Node node, List<String> instance) {
        if (node.label != null) {
            return node.label;
        }
        String attribute = node.attribute;
        String value = instance.get(data.get(0).indexOf(attribute));
        Node childNode = node.children.get(value);
        if (childNode == null) {
            return "Unknown"; // 无法预测
        }
        return predict(childNode, instance);
    }

    // 返回出现最多的类别标签
    private String mostCommonLabel(List<String> labels) {
        Map<String, Integer> labelCounts = new HashMap<>();
        for (String label : labels) {
            labelCounts.put(label, labelCounts.getOrDefault(label, 0) + 1);
        }
        String mostCommon = null;
        int maxCount = 0;
        for (Map.Entry<String, Integer> entry : labelCounts.entrySet()) {
            if (entry.getValue() > maxCount) {
                maxCount = entry.getValue();
                mostCommon = entry.getKey();
            }
        }
        return mostCommon;
    }

    public static void main(String[] args) {
        // 示例数据集
        List<List<String>> data = Arrays.asList(
                Arrays.asList("长", "长", "女"),
                Arrays.asList("短", "短", "男"),
                Arrays.asList("长", "短", "女"),
                Arrays.asList("短", "长", "男"),
                Arrays.asList("长", "长", "女"),
                Arrays.asList("长", "短", "男")
        );
        List<String> labels = Arrays.asList("女生", "男生", "女生", "男生", "女生", "男生");
        List<String> attributes = Arrays.asList("头发", "裤子");

        DecisionTree tree = new DecisionTree();
        Node root = tree.buildTree(data, labels, attributes);

        // 预测
        List<String> instance = Arrays.asList("长", "短");
        String prediction = tree.predict(root, instance);
        System.out.println("Prediction: " + prediction); // 输出预测结果
    }
}
  • Node类

    • Node(String attribute): 用于创建内部节点,保存当前节点的属性。
    • Node(String attribute, String label): 用于创建叶节点,保存当前节点的属性和标签。
    • attribute: 表示当前节点用来分裂数据集的属性。
    • children: 存储当前节点的子节点。
    • label: 如果当前节点是叶节点,则存储该节点的标签。
  • entropy方法

    • entropy(List<String> labels): 计算给定标签集合的熵。
    • labelCounts: 存储每个标签的频次。
    • probability: 每个标签的概率。
    • 熵的计算公式是根据信息论中定义的熵公式。
  • informationGain方法

    • informationGain(List<List<String>> data, List<String> labels, String attribute): 计算数据集在给定属性上的信息增益。
    • index: 属性在数据集中的索引。
    • originalEntropy: 原始数据集的熵。
    • subsets: 按属性值分割的数据子集。
    • newEntropy: 按属性值分割后各子集的熵的加权和。
    • 信息增益是原始熵与新熵的差值。
  • bestAttribute方法

    • bestAttribute(List<List<String>> data, List<String> labels, List<String> attributes): 选择具有最大信息增益的属性。
    • bestAttribute: 最佳分裂属性。
    • bestGain: 最大的信息增益。
  • buildTree方法

    • buildTree(List<List<String>> data, List<String> labels, List<String> attributes): 递归地构建决策树。
    • 如果所有实例的标签都相同,返回叶节点。
    • 如果没有剩余属性,返回出现最多的标签作为叶节点。
    • bestAttribute: 选择最佳分裂属性。
    • subsetssubLabels: 分别存储按最佳属性分割后的数据子集和标签子集。
    • 递归构建子树。
  • predict方法

    • predict(Node node, List<String> instance): 使用构建的决策树对新实例进行预测。
    • 如果当前节点是叶节点,返回标签。
    • 根据当前节点的属性,获取实例中对应的属性值。
    • 递归到子节点进行预测。
  • mostCommonLabel方法

    • mostCommonLabel(List<String> labels): 返回出现最多的类别标签。
    • labelCounts: 存储每个标签的频次。
    • 返回频次最高的标签。
  • main方法

    • 准备示例数据集和标签。
    • 构建决策树。
    • 对新实例进行预测,并输出预测结果。
  • 25
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

沉浮yu大海

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值