机器学习——决策树原理以及简单实现

一、决策树模型

1.决策树模型的基本概念

决策树是一种基于树形结构的分类和回归模型。它通过划分数据集来不断地构建树形结构,每个非叶子节点表示一个特征,每个叶子节点表示一个分类或回归的结果。在分类任务中,决策树可以根据样本的特征将其划分到不同的类别中;在回归任务中,决策树可以根据样本的特征预测其数值型输出。

决策树基于树的结构进行决策,从根节点开始,沿着划分属性进行分支,直到叶节点:

  • “内部结点”:有根结点和中间结点,某个属性上的测试(test),这里的test是针对属性进行判断
  • 分支:该测试的可能结果,属性有多少个取值,就有多少个分支
  • “叶节点”:预测结果

如下基于西瓜分类的一颗决策树:

2.决策树的划分算法

决策树学习的关键在于如何选择最优划分属性。一般而言,随着划分过程不断进行,我们希望决策树的分支结点所包含的样本尽可能属于同一类别,即结点的“纯度”(purity)越来越高。

经典的属性划分方法: 信息增益: ID 3 增益率:C 4.5 基尼指数:CART

2.1 信息增益

2.1.1信息熵

  “信息熵”是度量样本集合纯度最常用的一种指标,假定当前样本集合D中第k类样本所占的比例为pk (K=1, 2, ..., |y|) ,则D的信息熵定义为

   Ent(D)的值越小,则D的纯度越高 计算信息熵时约定:若p = 0,则plog2p=0 Ent(D)的最小值为0,最大值为log2|y|。

2.1.2离散属性

离散属性a有V个可能的取值{a1, a2, ..., aV},用a来进行划分,则会产生V个分支结点,其中第v个分支结点包含了D中所有在属性a上取值为av的样本,记为Dv。则可计算出用属性a对样本集D进行划分所获得的“信息增益”:

 一般而言,信息增益越大,则意味着使用属性a来进行划分所获得的“纯度提升”越大

2.2 增益率

可定义增益率:

   

其中

    称为属性a的“固有值” [Quinlan, 1993] ,属性a的可能取值数目越多(即V越大),则IV(a)的值通常就越大.

2.3 基尼指数

分类问题中,假设D有K个类,样本点属于第k类的概率为p_k,则概率分布的基尼值定义为:

Gini(D)越小,数据集D的纯度越高;

给定数据集D,属性a的基尼指数定义为:

在候选属性集合A中,选择那个使得划分后基尼指数最小的属性作为最有划分属性。

3.构建过程

构建决策树的过程可以概括为以下四个步骤:

  1. 特征选择:从所有特征中选择一个最优特征进行划分。常见的特征选择标准有信息增益(Information Gain)、信息增益比(Gain Ratio)、基尼指数(Gini Index)等。
  2. 决策树生成:根据选择的特征,将数据集划分为若干个子集。为每个子集生成对应的子节点,并将这些子节点作为当前节点的分支。对每个子节点,重复第1步和第2步,直到满足停止条件。
  3. 停止条件:当满足以下任一条件时,停止决策树的生成:
  • 所有特征已经被用于划分;
  • 所有子集中的样本都属于同一类别;
  • 子集中样本数量不足以继续划分。

  1. 剪枝:为了避免过拟合(Overfitting),可以对生成的决策树进行剪枝。常见的剪枝方法有预剪枝(Pre-pruning)和后剪枝(Post-pruning)。

参考自机器学习经典算法-决策树 - 知乎 (zhihu.com)

4.剪枝处理

4.1目的

“剪枝”是决策树学习算法对付“过拟合”的主要手段。

可通过“剪枝”来一定程度避免因决策分支过多,以致于把训练集自身的一些特点当做所有数据都具有的一般性质而导致的过拟合。

4.2剪枝方法

4.2.1预剪枝

 1当决策树达到预设的高度时就停止决策树的生长。

2. 达到某个节点的实例集具有相同的特征向量(属性取值相同),即使这些实例不属于同一类,也可以停止决策树的生长。  

3.定义一个阈值,当达到某个节点的实例个数小于阈值时就可以停止决策树的生长。  通过计算每次扩张对系统性能的增益,决定是否停止决策树的生长。

4.通过计算每次扩张对系统性能的增益,决定是否停止决策树的生长。

4.2.2后剪枝

先从训练集生成一棵完整的决策树,然后自底向上地对非叶结点进行分析计算,若将该结点对应的子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶结点。

优点 :后剪枝比预剪枝保留了更多的分支,欠拟合风险小,泛化性能往往优于预剪枝决策树。

缺点: 训练时间开销大:后剪枝过程是在生成完全决策树之后进行的,需要自底向上对所有非叶结点逐一计算。

二,决策树基本实现

1.示例

我使用了鸢尾花数据集作为示例数据。

  1. 导入必要的库:
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

这些库是用于实现决策树算法所必需的。datasets模块提供了一些经典的数据集,而train_test_split函数用于将数据集拆分为训练集和测试集。DecisionTreeClassifier类是scikit-learn库中用于实现决策树分类器的类。最后,accuracy_score函数用于计算模型的准确率。

    2.加载数据集:

iris = datasets.load_iris()
X = iris.data
y = iris.target

这里使用了鸢尾花数据集作为示例数据。通过load_iris函数加载数据集,并将特征数据赋值给X,将目标变量赋值给y

3.拆分数据集:

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

使用train_test_split函数将数据集按照指定的比例(本例中为80%训练集和20%测试集)进行拆分,并将拆分后的数据赋值给相应的变量。random_state参数用于设置随机种子,以确保每次运行时得到相同的拆分结果。

4.创建决策树分类器:

clf = DecisionTreeClassifier()

使用DecisionTreeClassifier类创建一个决策树分类器对象。这里没有传入任何参数,使用默认的参数设置。

5.训练模型:

clf.fit(X_train, y_train)

使用训练集数据和标签训练决策树模型。fit方法会根据数据学习特征之间的关系,并构建决策树模型。

6.进行预测:

y_pred = clf.predict(X_test)

使用训练好的模型对测试集数据进行预测,将预测结果赋值给y_pred

7.计算准确率:

accuracy = accuracy_score(y_test, y_pred)
print("准确率:", accuracy)

使用accuracy_score函数计算模型在测试集上的准确率,并将结果打印出来。accuracy_score函数将真实标签y_test和预测标签y_pred作为输入,通过比较它们的一致性来计算准确率。

以上代码实现了一个简单的决策树模型,在鸢尾花数据集上进行训练和测试,并输出了模型的准确率。

结果如下:

2.绘制决策树结构图

在原来代码上修改绘制决策树结构图如下:

plt.figure(figsize=(10, 10))
plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True, rounded=True)
plt.show()

得到结果如下:

三,总结

1.首先学习了解到了决策树是一种直观、易于理解和解释的机器学习算法。它通过构建一棵树形结构来选择最佳的决策路径,每个节点代表一个特征,每个边代表一个决策规则。可以直观地了解模型是如何做出预测的。

2.延续了前几次实验的基本步骤将数据集划分为训练集和测试集,并使用训练集对决策树模型进行训练,然后在测试集上进行预测。通过计算准确率去评估模型的性能。这能更好更深入的理解机器学习这门课程的知识。

3.在以上代码实现测试时遇到了可视化决策树的结构图的问题,最开始想通过使export_graphviz函数将生成的图像保存为PDF格式文件,命名为iris.pdf,并在程序所在的目录下找到这个文件。将打开系统默认的图像查看器,并显示决策树。代码如下:

# 将训练好的决策树模型导出为Graphviz格式
dot_data = export_graphviz(clf, out_file=None, 
                           feature_names=iris.feature_names,  
                           class_names=iris.target_names,  
                           filled=True, rounded=True,  
                           special_characters=True)  
graph = graphviz.Source(dot_data)  
graph.render("iris")  

# 查看决策树
graph.view()

但是始终有:

这类错误出现,尝试将Graphviz的bin目录添加到系统的环境变量中等多种方式无果后,选择以上的方式绘制决策树。这是本次实验还未能理解与应用的方法,需要多加尝试。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值