决策树
作用: 决策树是一种机器学习算法,用于分类和回归任务。
主要思想是:
通过询问一系列的判断性问题,将一个复杂的问题划分为更简单的子问题来求解。
一个决策树由三个要素组成:
- 结点: 表示一个判断条件。结点可以是根结点、内部结点和叶结点。
- 分支: 从一个结点的判断条件指向子结点(或叶结点)的路径。
- 叶结点: 终止判断条件的结点,对应模型的预测结果。
下面是一棵简单的决策树:
构建决策树的目的: 找到最优的判断条件来将训练数据划分为尽可能纯净的子集。
常用的划分准则有:
- 信息增益: 度量划分后子集的不纯度减少的程度。信息增益越大,表示划分后子集越纯净。
- 基尼指数: 反映了数据集的不纯度,值越小越好。基尼指数主要用于分类树。
- 减少方差: 用于回归树,度量划分后子集的方差减小程度。方差减小幅度越大,表示子集回归均值越接近目标值。
构建决策树的过程是递归进行的,主要步骤:
- 计算每个特征的划分准则(如信息增益),选择最优特征和划分点。
- 根据选择的最优特征和划分点,将训练数据划分为左右两个子集。
- 重复步骤1和2,直到满足停止条件(如结点包含的样本数少于阈值,找到100%同一类样本结点等)。
- 对于终止判断条件的叶结点,设置其分类结果(一般为叶结点中样本数最多的类)或回归结果(均值)。
决策树是一种简单但高效的机器学习算法。但需要防止过拟合,常使用的方法有: 剪枝、设置最小样本数和最大深度等。决策树也存在一定局限性,对outliers和噪声数据比较敏感。
决策树实例
小明大学毕业了来到一家银行当行长,上班第一天就有15位客人申请了贷款,刚刚入行的小明仔细整理了客户的基本信息。
编号 | 有工作 | 有房子 | 信誉 | 贷款结果 |
---|---|---|---|---|
1 | 否 | 否 | 一般 | 拒绝 |
2 | 否 | 否 | 好 | 拒绝 |
3 | 是 | 否 | 好 | 批准 |
4 | 是 | 是 | 一般 | 批准 |
5 | 否 | 否 | 一般 | 拒绝 |
6 | 否 | 否 | 一般 | 拒绝 |
7 | 否 | 否 | 好 | 拒绝 |
8 | 是 | 是 | 好 | 批准 |
9 | 否 | 是 | 非常好 | 批准 |
10 | 否 | 是 | 非常好 | 批准 |
11 | 否 | 是 | 非常好 | 批准 |
12 | 否 | 是 | 好 | 批准 |
13 | 是 | 否 | 好 | 批准 |
14 | 是 | 否 | 非常好 | 批准 |
15 | 否 | 否 | 一般 | 拒绝 |
有没有办法能判断一个客户是否能申请贷款呢?
如果以工作为标准
如果按照少数服从多数为原则的话,结论就是有工作的就会被批准。
即
这显然和样本的结果相悖。
如果以信誉为标准
如果按照少数服从多数为原则的话,结论就是有工作的就会被批准。
即
这显然和样本的结果相悖。
可不可以按照一定顺序使用两个标准来进行分类呢?
当一个样本:
编号 | 有工作 | 有房子 | 信誉 | 贷款结果 |
---|---|---|---|---|
16 | 否 | 否 | 非常好 | ? |
这里的贷款结果就应该是批准。
多个标准的顺序应该怎么样来选择呢?
标准的好坏可以用 基尼系数 定义。含义:标准的不确定程度。
基尼系数:
G
i
n
i
=
1
−
∑
1
k
p
k
2
Gini = 1 - ∑_1^k p_k^2
Gini=1−∑1kpk2
在上述例子中,
G
i
n
i
=
1
−
p
(
批准
)
2
−
p
(
拒绝
)
2
Gini=1-p(批准)^2-p(拒绝)^2
Gini=1−p(批准)2−p(拒绝)2
当p(批准)=1,p(拒绝)=0,则
G
i
n
i
=
1
−
1
−
0
=
0
Gini=1-1-0=0
Gini=1−1−0=0
当p(批准)=0,p(拒绝)=1,则
G
i
n
i
=
1
−
0
−
1
=
0
Gini=1-0-1=0
Gini=1−0−1=0
当p(批准)=0.5,p(拒绝)=0.5,则
G
i
n
i
=
1
−
0.25
−
0.25
=
0.5
Gini=1-0.25-0.25=0.5
Gini=1−0.25−0.25=0.5
当客户一定被拒绝或一定被批准,这种确定性会得到一个接近于0的基尼系数。我们只需要选择基尼系数最小的来作为决策树下一级分类的标准就可以了。
在上述例子中:
1、首先不考虑任何标准,根据贷款结果,直接计算数据的基尼系数。
G
i
n
i
=
1
−
p
(
批准
)
2
−
p
(
拒绝
)
2
=
1
−
(
9
/
15
)
2
−
(
6
/
15
)
2
=
0.48
Gini=1-p(批准)^2 - p(拒绝) ^2=1-(9/15)^2-(6/15)^2=0.48
Gini=1−p(批准)2−p(拒绝)2=1−(9/15)2−(6/15)2=0.48
说明在不考虑任何标准的情况下,数据是类似于随机生成的。
2、再考虑那些有工作的客户,没有工作的客户。
G
i
n
i
(
工作,是
)
=
1
−
p
(
批准
)
2
−
p
(
拒绝
)
2
=
1
−
(
5
/
5
)
2
−
(
0
)
2
=
0
Gini(工作,是)=1-p(批准)^2 - p(拒绝) ^2=1-(5/5)^2-(0)^2=0
Gini(工作,是)=1−p(批准)2−p(拒绝)2=1−(5/5)2−(0)2=0
G
i
n
i
(
工作,否
)
=
1
−
p
(
批准
)
2
−
p
(
拒绝
)
2
=
1
−
(
4
/
10
)
2
−
(
6
/
10
)
2
=
0.48
Gini(工作,否)=1-p(批准)^2 - p(拒绝) ^2=1-(4/10)^2-(6/10)^2=0.48
Gini(工作,否)=1−p(批准)2−p(拒绝)2=1−(4/10)2−(6/10)2=0.48
3、计算以工作为标准分类的基尼系数。
G
i
n
i
(
工作
)
=
5
/
5
G
i
n
i
(
工作,是
)
+
10
/
15
G
i
n
i
(
工作,否
)
=
0.32
Gini(工作)=5/5Gini(工作,是)+10/15Gini(工作,否)=0.32
Gini(工作)=5/5Gini(工作,是)+10/15Gini(工作,否)=0.32
4、以此类推:
G
i
n
i
(
房子
)
=
0.27
Gini(房子) = 0.27
Gini(房子)=0.27
G
i
n
i
(
信誉
)
=
0.28
Gini(信誉) = 0.28
Gini(信誉)=0.28
相比之下,Gini(房子)值最小。首先以是否有房子作为标准构建决策树。
下一级按什么标准划分呢?
G
i
n
i
=
1
−
(
3
/
9
)
2
−
(
6
/
9
)
2
=
0.44
Gini=1-(3/9)^2-(6/9)^2=0.44
Gini=1−(3/9)2−(6/9)2=0.44
G
i
n
i
(
工作
)
=
0
Gini(工作)=0
Gini(工作)=0
G
i
n
i
(
信誉
)
=
0.22
Gini(信誉)=0.22
Gini(信誉)=0.22
相比之下,Gini(工作)值最小。下一级以是否有工作作为标准构建决策树。
CART
- CART是分类与回归树(Classification and Regression Tree)的缩写。它是实现决策树的一种算法。通过递归二分的方式构建树结构,并在叶子节点给出预测输出。
- 决策树作为理论模型,CART算法是该模型的一种实用实现。
决策树是一种基本的机器学习方法,用于分类和回归任务。它通过构建树状结构,将样本空间递归地划分成更小的区域,从而实现预测。
CART算法是实现决策树的常用算法之一。
它的主要步骤包括:
- 选择最优划分特征和划分点: 遍历每个特征和每个可能的划分点,找到使得目标函数最优的特征和划分点。
- 根据最优划分,将当前节点切分为子节点: 子节点继承父节点的样本,但只保留符合划分条件的样本。
- 判断子节点是否符合停止条件: 如果没有更多特征可供选择或样本数量太小,则不再继续切分,将节点标记为叶子节点。
- 重复步骤1-3,直到所有子节点均为叶子节点: 构建完整的树结构。
- 对叶子节点进行分类/回归: 根据节点上样本的多数类别或平均值,确定叶子节点的分类/回归输出。
- 对新样本进行预测: 根据树结构,从根节点开始逐步判断新样本的划分,最终到达叶子节点,得到预测输出。
代码实现
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
# 构造训练集
X = [[0, 0], [1, 1]]
y = [0, 1]
# 构造测试集
X_test = [[2, 2]]
# 建立决策树模型
model = DecisionTreeClassifier()
# 训练模型
model.fit(X, y)
# 对测试集进行预测
y_pred = model.predict(X_test)
print(y_pred) # [1]
# 构建决策树
tree_graph = export_graphviz(model, out_file=None)
# 画出决策树结构
import graphviz
dot_data = export_graphviz(model, out_file=None, feature_names=["x1", "x2"],
class_names=["class0", "class1"],
filled=True, rounded=True,
special_characters=True)
graph = graphviz.Source(dot_data)
graph