机器学习第二章:决策树

本文介绍了决策树算法的基本概念,包括分类树和回归树的区别,构建过程中的属性选择和纯度度量(如信息熵和Gini系数)。通过ID3、C4.5和CART算法的比较,展示了鸢尾花数据分类案例,并讨论了决策树深度对过拟合的影响。
摘要由CSDN通过智能技术生成

1.1决策树算法简介

(1)决策树算法概念

定义:决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构建决策树来
进行分析的一种方式,是一种直观应用概率分析的一种图解法;决策树是一种预
测模型,代表的是对象属性与对象值之间的映射关系;决策树是一种树形结构,
其中每个内部节点表示一个属性的测试,每个分支表示一个测试输出,每个叶节
点代表一种类别;决策树是一种非常常用的有监督的分类算法。

决策树的决策过程就是从根节点开始,测试待分类项中对应的特征属性,并按照
其值选择输出分支,直到叶子节点,将叶子节点的存放的类别作为决策结果。

决策树分为两大类:分类树和回归树,前者用于分类标签值,后者用于预测连续
值,常用算法有ID3、C4.5、CART等

(2)决策树构建过程

决策树算法的重点就是决策树的构造;决策树的构造就是进行属性选择度量,确定各个特征
属性之间的拓扑结构(树结构);构建决策树的关键步骤就是分裂属性,分裂属性是指在某个节
点按照某一类特征属性的不同划分构建不同的分支,其目标就是让各个分裂子集尽可能的'纯
'(让一个分裂子类中待分类的项尽可能的属于同一个类别)。
构建步骤如下:
1. 将所有的特征看成一个一个的节点;
2. 遍历每个特征的每一种分割方式,找到最好的分割点;将数据划分为不同的子节点,eg: N1、
N2....Nm;计算划分之后所有子节点的'纯度'信息;
3. 对第二步产生的分割,选择出最优的特征以及最优的划分方式;得出最终的子节点: N1、N2....Nm
4. 对子节点N1、N2....Nm分别继续执行2-3步,直到每个最终的子节点都足够'纯'。

  • 优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
  • 缺点:可能会产生过度匹配的问题
  • 适用数据类型:数值型和标称型

1.2信息熵(Entropy)

信息熵:1948年,香农引入信息熵;一个系统越是有序,信息熵就越低,一个系统越是混乱,信息熵就越高,所以信息熵被认为是一个系统有序程度的度量。信息熵就是用来描述系统信息量的不确定度。

H\left ( X \right )=-\sum_{i=1}^{m}p_{i}log_{2}^{(p_{i})}

1.3决策树量化纯度

决策树的构建是基于样本概率和纯度进行构建操作的,那么进行判断数据集是否
“纯”可以通过三个公式进行判断,分别是Gini系数、熵(Entropy)、错误率,这
三个公式值越大,表示数据越“不纯”;越小表示越“纯”;实践证明这三种公
式效果差不多,一般情况使用熵公式

Gini=1-\sum_{i=1}^{n}p(i)^{2}

H\left ( X \right )=-\sum_{i=1}^{m}p_{i}log_{2}^{(p_{i})}

Error=1-max\left \{ p(i) \right \}

当计算出各个特征属性的量化纯度值后使用信息增益度来选择出当前数据集的分
割特征属性;如果信息增益度的值越大,表示在该特征属性上会损失的纯度越大 ,
那么该属性就越应该在决策树的上层,计算公式为:

Gain=\Delta =H(D)-H(D|A)

Gain为A为特征对训练数据集D的信息增益,它为集合D的经验熵H(D)与特征A给
定条件下D的经验条件熵H(D|A)之差

1.4三种决策树算法

(1)ID3算法

ID3算法是决策树的一个经典的构造算法,内部使用信息熵以及信息增益来进行
构建;每次迭代选择信息增益最大的特征属性作为分割属性。

H\left ( X \right )=-\sum_{i=1}^{m}p_{i}log_{2}^{(p_{i})}

Gain=\Delta =H(D)-H(D|A)

优点:
决策树构建速度快;实现简单;
缺点:
计算依赖于特征数目较多的特征,而属性值最多的属性并不一定最优
ID3算法不是递增算法
ID3算法是单变量决策树,对于特征属性之间的关系不会考虑
抗噪性差
只适合小规模数据集,需要将数据放到内存中

(2)C4.5算法

在ID3算法的基础上,进行算法优化提出的一种算法(C4.5);现在C4.5已经是特
别经典的一种决策树构造算法;使用信息增益率来取代ID3算法中的信息增益,
在树的构造过程中会进行剪枝操作进行优化;能够自动完成对连续属性的离散化
处理;C4.5算法在选中分割属性的时候选择信息增益率最大的属性,涉及到的公
式为:

H\left ( X \right )=-\sum_{i=1}^{m}p_{i}log_{2}^{(p_{i})}

Gain=\Delta =H(D)-H(D|A)

Gainratio(A)=\frac{Gain(A)}{H(A)}

优点:
产生的规则易于理解
准确率较高
实现简单
缺点:
对数据集需要进行多次顺序扫描和排序,所以效率较低
只适合小规模数据集,需要将数据放到内存中

(3)CART算法

使用基尼系数作为数据纯度的量化指标来构建的决策树算法就叫做
CART(Classification And Regression Tree,分类回归树)算法。CART算法使用
GINI增益作为分割属性选择的标准,选择GINI增益最大的作为当前数据集的分
割属性;可用于分类和回归两类问题。强调备注:CART构建是二叉树。

Gini=1-\sum_{i=1}^{n}p(i)^{2}

Gain=\Delta =H(D)-H(D|A)

1.5决策树案例一:鸢尾花数据分类

数据集:iris.csv_免费高速下载|百度网盘-分享无限制

使用决策树算法API对鸢尾花数据进行分类操作,并理解及进行决策树API的相关
参数优化

from six import StringIO
import pandas as pd
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import matplotlib as mpl
import pydotplus
 
 
iris = load_iris()
data = pd.DataFrame(iris.data)
data.columns = iris.feature_names
data['Species'] = load_iris().target
 
# 准备数据
x = data.iloc[:, 0:4]
y = data.iloc[:, -1]
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=42)
# 训练决策树模型
tree_clf = DecisionTreeClassifier(max_depth=8, criterion='gini')
tree_clf.fit(x_train, y_train)
 
# 用测试集进行预测,得出精确率
y_test_hat = tree_clf.predict(x_test)
print("acc score:", accuracy_score(y_test, y_test_hat))
print(tree_clf.feature_importances_)
 
# 将决策树保存成图片
dot_data = StringIO()
tree.export_graphviz(
    tree_clf,
    out_file=dot_data,
    feature_names=iris.feature_names[:],
    class_names=iris.target_names,
    rounded=True,
    filled=True
)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png('tree.png')

acc score: 1.0
[0.03575134 0.         0.88187037 0.08237829]

 我们看出用测试集预测的精确度为100%,特征中,第三个特征的重要程度系数最大,说明它对决策分类起很重要的作用

修改最大深度

depth = np.arange(1, 15)
err_list = []
for d in depth:
    print(d)
    clf = DecisionTreeClassifier(criterion='gini', max_depth=d)
    clf.fit(x_train, y_train)
    y_test_hat = clf.predict(x_test)
    result = (y_test_hat == y_test)
    if d == 1:
        print(result)
    err = 1 - np.mean(result)
    print(100 * err)
    err_list.append(err)
    print(d, ' 错误率:%.2f%%' % (100 * err))
 
mpl.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(facecolor='w')
plt.plot(depth, err_list, 'ro-', lw=2)
plt.xlabel('决策树深度', fontsize=15)
plt.ylabel('错误率', fontsize=15)
plt.title('决策树深度和过拟合', fontsize=18)
plt.grid(True)
plt.show()

 

 从图中我们看出当决策树深度等于3的时候,错误率就以及接近0了。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值