引言
决策树算法是机器学习领域中一种经典且常用的算法,它通过构建一棵树形结构来对数据进行分类或回归预测。决策树算法具有直观、易于理解和解释的特点,广泛应用于各种领域,如医学诊断、金融风险评估和客户分类等。本篇博客将介绍决策树算法的基础知识,包括算法原理、特征选择准则和树的构建过程,并提供一个示例来演示如何应用决策树算法。
目录
1、决策树简介
决策树算法是一种监督学习算法,适用于分类和回归问题。它以树状结构表示决策规则,每个内部节点代表一个特征属性,每个叶子节点代表一个类别或一个预测值。决策树算法通过对数据进行划分,使得每个划分后的子集尽可能地纯净,即同一子集内的样本属于同一类别或具有相似的预测值。
2、决策树算法
决策树的基本原理是通过选择最优特征进行数据划分,使得划分后的子集尽可能地纯净。算法的核心步骤包括:选择最优特征、划分数据集和递归构建子树。通过不断重复这些步骤,直到满足终止条件,即达到预定的停止条件或无法继续划分时,构建出一棵完整的决策树。
2.1特征选择准则:
这里提出三种典型的决策树算法:ID3算法、C4.5算法、CART算法
ID3-信息增益(information gain)
ID3算法是一种基于信息增益进行特征选择的决策树算法。信息增益是一种衡量特征对数据划分贡献程度的指标,表示通过使用某个特征进行划分后,得到的子集纯度的提升程度。
对于给定的数据集D,计算其熵(entropy)。熵是用来度量数据集的混乱程度,计算公式为:
其中,p(x)表示数据集中属于类别x的样本在数据集中的比例。
对于每个特征A,计算其条件熵(conditional entropy)。条件熵是在特征A已知的情况下,数据集D的熵,计算公式为:
其中,Dv表示根据特征A划分后的第v个子集,|Dv|表示子集Dv的样本数量,|D|表示整个数据集D的样本数量。
计算特征A的信息增益(Information Gain),即数据集D的熵减去特征A的条件熵:
C4.5-增益率(gain ratio)
C4.5-增益率(gain ratio)是ID3算法的改进版本。C4.5算法在选择最优特征进行划分时,使用增益率(gain ratio)作为衡量特征重要性的准则。
增益率是在信息增益的基础上引入了一个比率因子,用来解决信息增益对取值数目较多的特征有偏好的问题。增益率通过对信息增益进行归一化来惩罚具有更多取值的特征。
增益率的计算公式如下:
其中,Gain(A)表示特征A的信息增益,SplitInfo(A)表示特征A的分裂信息,计算公式为:
其中,p(t)表示特征A的第t个取值在数据集中的比例。
最终,在选择最优特征时,C4.5算法会计算所有特征的增益率,并选择增益率最大的特征作为划分节点。
CART-基尼系数
CART算法使用基尼系数(Gini index)来选择最优的特征进行划分。
基尼系数是用来衡量一个随机变量的不确定性的指标,计算公式为:
其中,p表示每个类别在样本中出现的概率,pi表示属于第i个类别的样本占总样本数的比例。基尼系数越小,表示数据集中的样本越倾向于属于同一个类别。
在CART算法中,选择最优的特征时,会使用基尼系数来计算划分后的子集的纯度,并选择使得基尼系数最小的特征作为划分节点。具体而言,对于给定的数据集D和特征A,可以将数据集D划分成两个子集D1和D2,分别包含特征A的不同取值。那么,划分后的基尼系数为:
其中,|D1|和|D2|分别为子集D1和D2的样本数量,Gini(D1)和Gini(D2)分别为子集D1和D2的基尼系数。
选择基尼系数最小的特征作为划分节点,可以使得每个子集的基尼系数最小,从而提高决策树的准确性。
三种方法的特点
ID3-信息增益(information gain):
ID3算法存在一些限制。它倾向于选择取值较多的特征,因为这样的特征可以带来更多的划分可能性,可能导致过拟合问题。此外,ID3算法不能处理连续型特征,只能处理离散型特征。
C4.5-增益率(gain ratio):
相比于ID3算法,C4.5算法能够处理连续型特征,并且通过增益率的引入,可以避免对取值较多的特征的偏好。然而,增益率的计算需要额外考虑特征的取值数目,增加了一定的计算开销。此外,增益率对于取值较少的特征可能会偏向较多取值的特征。
CART-基尼系数:
与信息增益不同,基尼系数并不考虑变量的取值次数,因此对于取值较多的特征,基尼系数和信息增益可能会得出不同的结论。此外,基尼系数也不能处理连续型特征,只能处理离散型特征。
2.2 剪枝处理
剪枝是决策树学习算法中的一种常用技术,它可以有效地防止过拟合的发生。
决策树生成时,为了使决策树的复杂度适当,通常会进行剪枝操作。决策树剪枝分为预剪枝和后剪枝两种方式。
1.预剪枝
在决策树生成过程中,每次划分节点前都先进行估计,若当前节点无法带来决策树泛化性能的提升,则停止划分并将该节点标记为叶节点。判断当前节点是否需要停止划分,可以通过交叉验证、信息增益等方法进行估计。
预剪枝的优点在于可以减少决策树的深度,节约计算资源,同时也可以避免过拟合。然而,预剪枝可能会导致欠拟合,因为它限制了决策树的生长,使得一些重要的特征被忽略掉。
2.后剪枝
决策树生成完毕后,自底向上对非叶节点进行考察,若将该节点对应的子树替换成叶节点能带来决策树泛化性能的提升,则将该子树替换成叶节点。判断该节点是否应该替换,可以通过验证集、复杂度惩罚等方法进行估计。
后剪枝的优点在于它不会限制决策树的生长,因此可以更好地利用重要的特征。同时,后剪枝也可以减少过拟合现象,并且在决策树生成完毕后进行,其结果相对稳定。
3.注意
的是,在进行剪枝时,应该尽量避免剪枝过多,否则可能会使得决策树欠拟合。此外,剪枝的效果也取决于所使用的剪枝方法和超参数的设定。
3.决策树优缺点
优点:
-
直观易懂:决策树可以生成清晰的规则集,易于理解和解释。它们提供了一种直观的方式来解释数据中的关联和模式。
-
适应多类型数据:决策树可以处理离散型和连续型的特征,也可以处理具有多类别输出的分类问题。
-
特征选择:决策树可以自动选择最重要的特征,并用于构建决策规则。这使得决策树在处理高维数据和特征选择方面具有优势。
-
可解释性:由于决策树的规则清晰可见,它们提供了对决策过程的可解释性,可以帮助人们理解模型的决策依据。
缺点:
-
容易过拟合:当决策树生长过深时,容易学习到训练数据的细节和噪声,导致模型过拟合,泛化能力较差。剪枝技术可以用来缓解这个问题。
-
不稳定性:对于数据中微小的变动,决策树可能会生成完全不同的树结构。这使得决策树在数据集稍有变化时,具有较高的不稳定性。
-
忽略特征间的相关性:决策树的每个节点只考虑单个特征,忽略了特征之间的相关性。这可能导致模型无法捕捉到特征之间的复杂关系。
-
数据不平衡问题:在处理具有不平衡类别分布的数据时,决策树容易偏向具有更多样本的类别,导致分类性能不佳。
-
高计算复杂度:在大规模数据集上,决策树的训练和预测计算复杂度较高,需要消耗较多的时间和计算资源。
4.决策树的应用示例
这里我们使用鸢尾花数据集来实现决策树,鸢尾花数据集中包含花萼长度、花萼宽度、花瓣长度和花瓣宽度等特征,以及对应的鸢尾花品种。我们可以使用决策树算法对这些特征进行分析,并根据特征的划分对鸢尾花进行分类。
决策树的构建过程:
决策树的构建是一个递归的过程。从根节点开始,选择最优特征进行划分,然后根据划分结果构建子树。对每个子集重复这个过程,直到满足终止条件。常用的终止条件包括:达到预定的停止条件、无法继续划分或子集中的样本属于同一类别。
决策树构建过程如下:
def build_decision_tree(X, y):
if len(np.unique(y)) == 1:
# 如果所有样本都属于同一个类别,则创建叶节点并返回
return DecisionTreeNode(class_label=y[0])
num_features = X.shape[1]
best_feature_index = None
best_threshold = None
best_gini = float('inf')
for feature_index in range(num_features):
thresholds = np.unique(X[:, feature_index])
for threshold in thresholds:
left_indices = X[:, feature_index] <= threshold
right_indices = X[:, feature_index] > threshold
gini = (gini_impurity(y[left_indices]) * sum(left_indices) +
gini_impurity(y[right_indices]) * sum(right_indices)) / len(y)
if gini < best_gini:
best_gini = gini
best_feature_index = feature_index
best_threshold = threshold
if best_feature_index is None:
# 如果无法再分割样本,则创建叶节点并返回
return DecisionTreeNode(class_label=most_common_class(y))
node = DecisionTreeNode(feature_index=best_feature_index, threshold=best_threshold)
left_indices = X[:, best_feature_index] <= best_threshold
right_indices = X[:, best_feature_index] > best_threshold
node.left = build_decision_tree(X[left_indices], y[left_indices])
node.right = build_decision_tree(X[right_indices], y[right_indices])
return node
def gini_impurity(y):
classes, counts = np.unique(y, return_counts=True)
probabilities = counts / len(y)
gini = 1 - np.sum(probabilities**2)
return gini
def most_common_class(y):
classes, counts = np.unique(y, return_counts=True)
most_common_index = np.argmax(counts)
most_common_class = classes[most_common_index]
return most_common_class
def plot_decision_tree(node, depth=0, position=0, parent_position=0):
x = (position + parent_position) / 2
y = depth * -1
if node.class_label is not None:
plt.annotate(f"{node.class_label}", (x, y), xycoords='data',
xytext=(0, -20), textcoords='offset points',
fontsize=12, ha='center', va='center')
else:
plt.annotate(f"{feature_names[node.feature_index]} <= {node.threshold}", (x, y), xycoords='data',
xytext=(0, -20), textcoords='offset points',
fontsize=12, ha='center', va='center')
plot_decision_tree(node.left, depth - 1, position - (2 ** depth), x)
plot_decision_tree(node.right, depth - 1, position + (2 ** depth), x)
这段代码实现了一个决策树的构建过程,包含以下几个核心函数:
-
build_decision_tree(X, y)
函数:该函数的输入为一个样本矩阵 X 和对应的目标变量数组 y,输出为一个 DecisionTreeNode 对象,代表构建出的决策树。该函数首先判断当前样本是否只属于同一类别,如果是,则创建叶节点并返回;否则,通过计算所有可能的特征和阈值组合的 Gini 不纯度(gini_impurity),找到最优的特征和阈值组合,创建一个决策节点,并递归调用自身,对左右子树进行分裂,直到无法再分裂。最后返回整棵树的根节点对象。 -
gini_impurity(y)
函数:该函数计算给定目标变量数组 y 的 Gini 不纯度。 -
most_common_class(y)
函数:该函数统计目标变量数组 y 中出现频率最高的类别。 -
plot_decision_tree(node, depth, position, parent_position)
函数:该函数将决策树对象 node 可视化,其中 depth 表示当前深度,position 表示当前节点的位置,parent_position 则表示父节点的位置。该函数使用 plt.annotate() 函数在图中添加文本标签,展示决策规则。如果当前节点不是叶节点,则会递归调用自身,将左右子树进行可视化。
使用鸢尾花数据集进行的决策树的应用示例并且得到决策树图形
决策树构建过程如上。这里为了方便显示以及画图,不再使用具体的过程构建决策树,而是使用sklearn库中的DecisionTreeClassifier来实现决策树算法。
需要注意的是
sklearn.tree.DecisionTreeClassifier()
函数默认使用的是 Gini 不纯度(Gini impurity)作为分裂标准进行决策树的构建。此外,该函数也提供了criterion
参数,可以通过设置该参数来改变分裂标准,可选的值包括 "gini" 和 "entropy",分别对应 Gini 不纯度和信息熵(Entropy)。例如,若需要使用信息熵作为分裂标准,则可以将criterion
参数设置为 "entropy"。
使用scikit-learn库来完成
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt
# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
target_names = list(iris.target_names)
# 创建决策树分类器
clf = DecisionTreeClassifier()
clf.fit(X, y)
# 绘制决策树
plt.figure(figsize=(12, 8))
plot_tree(clf, feature_names=feature_names, class_names=target_names, filled=True)
plt.show()
运行结果: