一、决策树算法概述
决策树是一种常用的分类与回归方法,适用于处理非线性关系的数据。它通过一系列的分裂操作,将数据集划分成不同的子集,从而实现数据分类或回归。决策树的优点包括易于理解和解释,处理高维数据的能力强,并且可以处理缺失数据。本文将详细介绍决策树的原理,并通过案例代码实现决策树分类。
二、决策树的原理
1. 树的结构
决策树由节点和边组成:
- 根节点:树的起始点,代表整个数据集。
- 内部节点:每个内部节点表示对某个属性的测试,并根据测试结果将数据集划分为不同的子集。
- 叶节点:表示分类结果或回归值。
2. 树的生成
决策树的生成过程可以通过递归地选择最佳分裂点来构建。这通常包括以下步骤:
- 选择最佳属性:选择能够最大化信息增益的属性进行分裂。
- 分裂数据集:根据选定的属性,将数据集分裂为不同的子集。
- 递归构建子树:对子集递归地应用上述步骤,直到满足停止条件(如所有实例属于同一类别,属性用尽,或达到最大深度)。
3. 信息增益
信息增益是选择最佳分裂属性的常用标准。信息增益基于熵(Entropy)的概念,熵用于衡量数据集的纯度。信息增益计算公式如下:
1. 数据准备
Iris数据集包含150个样本,每个样本有四个特征(萼片长度、萼片宽度、花瓣长度、花瓣宽度)和一个类别标签(Setosa、Versicolour、Virginica)。
2. 代码实现
示例代码:决策树分类
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
class DecisionTree {
// 节点类
class Node {
String attribute;
Map<String, Node> children = new HashMap<>();
String label;
Node(String attribute) {
this.attribute = attribute;
}
Node(String attribute, String label) {
this.attribute = attribute;
this.label = label;
}
}
private Node root;
// 计算熵
private double entropy(List<String> labels) {
Map<String, Integer> labelCounts = new HashMap<>();
for (String label : labels) {
labelCounts.put(label, labelCounts.getOrDefault(label, 0) + 1);
}
double entropy = 0.0;
for (int count : labelCounts.values()) {
double probability = (double) count / labels.size();
entropy -= probability * Math.log(probability) / Math.log(2);
}
return entropy;
}
// 计算信息增益
private double informationGain(List<List<String>> data, List<String> labels, String attribute) {
int index = data.get(0).indexOf(attribute);
double originalEntropy = entropy(labels);
Map<String, List<String>> subsets = new HashMap<>();
for (int i = 0; i < data.size(); i++) {
String key = data.get(i).get(index);
subsets.putIfAbsent(key, new ArrayList<>());
subsets.get(key).add(labels.get(i));
}
double newEntropy = 0.0;
for (List<String> subset : subsets.values()) {
newEntropy += ((double) subset.size() / labels.size()) * entropy(subset);
}
return originalEntropy - newEntropy;
}
// 选择最佳分裂属性
private String bestAttribute(List<List<String>> data, List<String> labels, List<String> attributes) {
String bestAttribute = null;
double bestGain = Double.NEGATIVE_INFINITY;
for (String attribute : attributes) {
double gain = informationGain(data, labels, attribute);
if (gain > bestGain) {
bestGain = gain;
bestAttribute = attribute;
}
}
return bestAttribute;
}
// 生成决策树
public Node buildTree(List<List<String>> data, List<String> labels, List<String> attributes) {
if (new HashSet<>(labels).size() == 1) {
return new Node(null, labels.get(0)); // 只有一个类别,返回叶节点
}
if (attributes.isEmpty()) {
return new Node(null, mostCommonLabel(labels)); // 属性用尽,返回多数类别
}
String bestAttribute = bestAttribute(data, labels, attributes);
Node node = new Node(bestAttribute);
Map<String, List<List<String>>> subsets = new HashMap<>();
Map<String, List<String>> subLabels = new HashMap<>();
for (int i = 0; i < data.size(); i++) {
String key = data.get(i).get(data.get(0).indexOf(bestAttribute));
subsets.putIfAbsent(key, new ArrayList<>());
subLabels.putIfAbsent(key, new ArrayList<>());
subsets.get(key).add(data.get(i));
subLabels.get(key).add(labels.get(i));
}
for (String value : subsets.keySet()) {
List<String> subAttributes = new ArrayList<>(attributes);
subAttributes.remove(bestAttribute);
node.children.put(value, buildTree(subsets.get(value), subLabels.get(value), subAttributes));
}
return node;
}
// 预测
public String predict(Node node, List<String> instance) {
if (node.label != null) {
return node.label;
}
String attribute = node.attribute;
String value = instance.get(data.get(0).indexOf(attribute));
Node childNode = node.children.get(value);
if (childNode == null) {
return "Unknown"; // 无法预测
}
return predict(childNode, instance);
}
// 返回出现最多的类别标签
private String mostCommonLabel(List<String> labels) {
Map<String, Integer> labelCounts = new HashMap<>();
for (String label : labels) {
labelCounts.put(label, labelCounts.getOrDefault(label, 0) + 1);
}
String mostCommon = null;
int maxCount = 0;
for (Map.Entry<String, Integer> entry : labelCounts.entrySet()) {
if (entry.getValue() > maxCount) {
maxCount = entry.getValue();
mostCommon = entry.getKey();
}
}
return mostCommon;
}
public static void main(String[] args) {
// 示例数据集
List<List<String>> data = Arrays.asList(
Arrays.asList("长", "长", "女"),
Arrays.asList("短", "短", "男"),
Arrays.asList("长", "短", "女"),
Arrays.asList("短", "长", "男"),
Arrays.asList("长", "长", "女"),
Arrays.asList("长", "短", "男")
);
List<String> labels = Arrays.asList("女生", "男生", "女生", "男生", "女生", "男生");
List<String> attributes = Arrays.asList("头发", "裤子");
DecisionTree tree = new DecisionTree();
Node root = tree.buildTree(data, labels, attributes);
// 预测
List<String> instance = Arrays.asList("长", "短");
String prediction = tree.predict(root, instance);
System.out.println("Prediction: " + prediction); // 输出预测结果
}
}
-
Node类
Node(String attribute)
: 用于创建内部节点,保存当前节点的属性。Node(String attribute, String label)
: 用于创建叶节点,保存当前节点的属性和标签。attribute
: 表示当前节点用来分裂数据集的属性。children
: 存储当前节点的子节点。label
: 如果当前节点是叶节点,则存储该节点的标签。
-
entropy方法
entropy(List<String> labels)
: 计算给定标签集合的熵。labelCounts
: 存储每个标签的频次。probability
: 每个标签的概率。- 熵的计算公式是根据信息论中定义的熵公式。
-
informationGain方法
informationGain(List<List<String>> data, List<String> labels, String attribute)
: 计算数据集在给定属性上的信息增益。index
: 属性在数据集中的索引。originalEntropy
: 原始数据集的熵。subsets
: 按属性值分割的数据子集。newEntropy
: 按属性值分割后各子集的熵的加权和。- 信息增益是原始熵与新熵的差值。
-
bestAttribute方法
bestAttribute(List<List<String>> data, List<String> labels, List<String> attributes)
: 选择具有最大信息增益的属性。bestAttribute
: 最佳分裂属性。bestGain
: 最大的信息增益。
-
buildTree方法
buildTree(List<List<String>> data, List<String> labels, List<String> attributes)
: 递归地构建决策树。- 如果所有实例的标签都相同,返回叶节点。
- 如果没有剩余属性,返回出现最多的标签作为叶节点。
bestAttribute
: 选择最佳分裂属性。subsets
和subLabels
: 分别存储按最佳属性分割后的数据子集和标签子集。- 递归构建子树。
-
predict方法
predict(Node node, List<String> instance)
: 使用构建的决策树对新实例进行预测。- 如果当前节点是叶节点,返回标签。
- 根据当前节点的属性,获取实例中对应的属性值。
- 递归到子节点进行预测。
-
mostCommonLabel方法
mostCommonLabel(List<String> labels)
: 返回出现最多的类别标签。labelCounts
: 存储每个标签的频次。- 返回频次最高的标签。
-
main方法
- 准备示例数据集和标签。
- 构建决策树。
- 对新实例进行预测,并输出预测结果。