【机器学习之决策树详解 (二) 】

1.CART 分类决策树

1.1 Cart树简介

Cart模型是一种决策树模型,它即可以用于分类,也可以用于回归,其学习算法分为下面两步:

(1)决策树生成:用训练数据生成决策树,生成树尽可能大

(2)决策树剪枝:基于损失函数最小化的剪枝,用验证数据对生成的数据进行剪枝。

分类和回归树模型采用不同的最优化策略。Cart回归树使用平方误差最小化策略,Cart分类生成树采用的基尼指数最小化策略。

Scikit-learn中有两类决策树,他们均采用优化的Cart决策树算法。一个是DecisionTreeClassifier一个是DecisionTreeRegressor回归。

1.2 基尼指数计算公式

在这里插入图片描述
①信息增益(ID3)、信息增益率值越大(C4.5),则说明优先选择该特征。

②基尼指数值越小(cart),则说明优先选择该特征。

1.3 基尼指数计算举例

在这里插入图片描述

1.3.1 是否有房

计算过程如下:根据是否有房将目标值划分为两部分:
在这里插入图片描述

1.3.2 婚姻状况

在这里插入图片描述

1.3.3 年收入

先将数值型属性升序排列,以相邻中间值作为待确定分裂点:
在这里插入图片描述
以此类推计算所有分割点的基尼指数,我们发现最小的基尼指数为 0.3。

此时,我们发现:
①以是否有房作为分裂点的基尼指数为:0.343
②以婚姻状况为分裂特征、以 married 作为分裂点的基尼指数为:0.3
③以年收入作为分裂特征、以 97.5 作为分裂点的的基尼指数为:0.3

最小基尼指数有两个分裂点,我们随机选择一个即可,假设婚姻状况,则
可确定决策树如下:
在这里插入图片描述
重复上面步骤,直到每个叶子结点纯度达到最高.

1.4 Cart分类树原理

如果目标变量是离散变量,则是classfication Tree分类树。
分类树是使用树结构算法将数据分成离散类的方法。

(1)分类树两个关键点:

将训练样本进行递归地划分自变量空间进行建树‚用验证数据进行剪枝。

(2)对于离散变量X(x1…xn)处理:

分别取X变量各值的不同组合,将其分到树的左枝或右枝,并对不同组合而产生的树,进行评判,找出最佳组合。如果只有两个取值,直接根据这两个值就可以划分树。取值多于两个的情况就复杂一些了,如变量年纪,其值有“少年”、“中年”、“老年”,则分别生产{少年,中年}和{老年},{少年、老年}和{中年},{中年,老年}和{少年},这三种组合,最后评判对目标区分最佳的组合。因为CART二分的特性,当训练数据具有两个以上的类别,CART需考虑将目标类别合并成两个超类别,这个过程称为双化。这里可以说一个公式,n个属性,可以分出(2^n-2)/2种情况。

CART树生成
输入:数据集 D,特征 A,样本个数阈值、基尼系数阈值
输出:CART决策树T

(1)对于当前节点的数据集为D,如果样本个数小于阈值或者没有特征,则返回决策子树,当前节点停止递归;

(2)计算样本集D的基尼系数,如果基尼系数小于阈值,则返回决策树子树,当前节点停止递归;

(3)计算当前节点现有的各个特征的各个特征值对数据集D的基尼系数;

(4)在计算出来的各个特征的各个特征值对数据集D的基尼系数中,选择基尼系数最小的特征A和对应的特征值α。根据这个最优特征和最优特征D1值,把数据集划分成两部分D1和D2,同时建立当前节点的左右节点,D左节点的数据集为D1,右节点的数据集为D2;

(5)对左右的子节点递归的调用前面1-4步,生成决策树。

CART树剪枝
我们知道,决策树算法对训练集很容易过拟合,导致泛化能力很差,为解决此问题,需要对CART树进行剪枝。CART剪枝算法从“完全生长”的决策树的底端剪去一些子树,使决策树变小,从而能够对未知数据有更准确的预测,也就是说CART使用的是后剪枝法。一般分为两步:先生成决策树,产生所有可能的剪枝后的CART树,然后使用交叉验证来检验各种剪枝的效果,最后选择泛化能力好的剪枝策略。

1.5 使用CART算法构建决策树

import numpy as np
import matplotlib.pyplot as plt

from sklearn import datasets

iris = datasets.load_iris()
X = iris.data[:,2:]
y = iris.target

from sklearn.tree import DecisionTreeClassifier

#注意:此处传入的是"gini"而不是"entropy",默认criterion='gini'
tree = DecisionTreeClassifier(max_depth=2,criterion="gini")
tree.fit(X,y)

def plot_decision_boundary(model,axis):
    x0,x1 = np.meshgrid(
        np.linspace(axis[0],axis[1],int((axis[1]-axis[0])*100)).reshape(-1,1),
        np.linspace(axis[2],axis[3],int((axis[3]-axis[2])*100)).reshape(-1,1)
    )
    X_new = np.c_[x0.ravel(),x1.ravel()]
    y_predict = model.predict(X_new)
    zz = y_predict.reshape(x0.shape)

    from matplotlib.colors import ListedColormap
    custom_map = ListedColormap(["#EF9A9A","#FFF59D","#90CAF9"])

    plt.contourf(x0,x1,zz,linewidth=5,cmap=custom_map)

plot_decision_boundary(tree,axis=[0.5,7.5,0,3])
plt.scatter(X[y==0,0],X[y==0,1])
plt.scatter(X[y==1,0],X[y==1,1])
plt.scatter(X[y==2,0],X[y==2,1])
plt.show()

在这里插入图片描述
分析上图可知
在这里插入图片描述
①X[1] <=0.8 作为第一次分割的依据,满足条件的所有样本均为同一类别,gini系数为0.667

②X[1]>0.8的,依据 X[1]<=0.75 为划分依据

2. Cart回归决策树

2.1 回归决策树构建原理

CART 回归树和 CART 分类树的不同之处在于:

①CART 分类树预测输出的是一个离散值,CART 回归树预测输出的是一个连续值。
②CART 分类树使用基尼指数作为划分、构建树的依据,CART 回归树使用平方损失。
③分类树使用叶子节点里出现更多次数的类别作为预测类别,回归树则采用叶子节点里均值作为预测输出

CART 回归树构建:
在这里插入图片描述

例子:
假设:数据集只有 1 个特征 x, 目标值值为 y,如下图所示:

x12345678910
y5.565.75.916.46.87.058.98.799.05

由于只有 1 个特征,所以只需要选择该特征的最优划分点,并不需要计算其他特征。
1.先将特征 x 的值排序,并取相邻元素均值作为待划分点,如下图所示:

s1.52.53.54.55.56.57.58.59.510.5

1.计算每一个划分点的平方损失,例如:1.5 的平方损失计算过程为:
R1 为 小于 1.5 的样本个数,样本数量为:1,其输出值为:5.56
R 1 = 5.56 R_1=5.56 R1=5.56
R2 为 大于 1.5 的样本个数,样本数量为:9 ,其输出值为:
R 1 = ( 5.7 + 5.91 + 6.4 + 6.8 + 7.05 + 8.9 + 8.7 + 9 + 9.05 ) / 9 = 7.5 R_1= (5.7 + 5.91+6.4+6.8+7.05+ 8.9+8.7 +9 + 9.05)/9=7.5 R1=5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05)/9=7.5
该划分点的平方损失:
L ( 1.5 ) = ( 5.56 − 5.56 ) 2 + [ ( 5.7 − 7.5 ) 2 + ( 5.91 − 7.5 ) 2 + . . . . ( 9.05 − 7.5 ) 2 ] = 0 + 15.72 = 15.72 L(1.5)=(5.56-5.56)^2+[(5.7-7.5)^2+(5.91-7.5)^2+....(9.05-7.5)^2]=0+15.72=15.72 L(1.5)=(5.565.56)2+[(5.77.5)2+(5.917.5)2+....(9.057.5)2]=0+15.72=15.72
以此方式计算 2.5、3.5… 等划分点的平方损失,结果如下所示:

s1.52.53.54.55.56.57.58.59.5
m(s)15.7212.078.365.783.911.938.0111.7315.74

当划分点 s=6.5 时,m(s) 为1.93最小。因此,第一个划分变量:特征为 X, 切分点为 6.5,即:j=x, s=6.5
在这里插入图片描述
对左子树的 6 个结点计算每个划分点的平方式损失,找出最优划分点:

123456
5.565.75.916.46.87.05
s1.52.53.54.55.5
c15.565.635.725.896.07
c26.376.546.756.937.05
s1.52.53.54.55.5
m(s)1.30870.7540.27710.43681.0644

s=3.5时,m(s) 最小,所以左子树继续以 3.5 进行分裂:

在这里插入图片描述
假设在生成3个区域 之后停止划分,以上就是回归树。每一个叶子结点的输出为:挂在该结点上的所有样本均值。

x12345678910
y5.565.75.916.46.87.058.98.799.05

1号样本真实值 5.56 预测结果:5.72

2号样本真实值是 5.7 预测结果:5.72

3 号样本真实值是 5.91 预测结果 5.72

CART 回归树构建过程如下:

①选择第一个特征,将该特征的值进行排序,取相邻点计算均值作为待划分点
②根据所有划分点,将数据集分成两部分:R1、R2
③R1 和 R2 两部分的平方损失相加作为该切分点平方损失
④取最小的平方损失的划分点,作为当前特征的划分点
⑤以此计算其他特征的最优划分点、以及该划分点对应的损失值
⑥在所有的特征的划分点中,选择出最小平方损失的划分点,作为当前树的分裂点

2 剪枝

2.1 什么是剪枝

剪枝 (pruning)是决策树学习算法对付 过拟合 的主要手段。

在决策树学习中,为了尽可能正确分类训练样本,结点划分过程将不断重复,有时会造成决策树分支过多,这时就可能因训练样本学得"太好"了,以致于把训练集自身的一些特点当作所有数据都具有的一般性质而导致过拟合。因此,可通过主动去掉一些分支来降低过拟合的风险。

2.2 为什么要进行树的剪枝?

决策树是充分考虑了所有的数据点而生成的复杂树,有可能出现过拟合的情况,决策树越复杂,过拟合的程度会越高。

考虑极端的情况:如果我们令所有的叶子节点都只含有一个数据点,那么我们能够保证所有的训练数据都能准确分类,但是很有可能得到高的预测误差,原因是将训练数据中所有的噪声数据都”准确划分”了,强化了噪声数据的作用。

而剪枝修剪分裂前后分类误差相差不大的子树,能够降低决策树的复杂度,降低过拟合出现的概率。

关键步骤解释:

因为决策树的构建过程是一个递归的过层,所以必须确定停止条件,否则过程将不会停止,树会不停生长。通过我们前面的例子我们可以当一个节点下面的所有记录都属于同一类,或者当所有记录属性都具有相同的值时停止,但是这样往往会使得树的节点过多,导致过度拟合的问题。

过度拟合是指直接生成的完全决策树对训练样本的特征描述的“过于精确”,无法实现对新样本进行合理的分许,所以这种情况我们构建的树不是一颗最佳的决策树。

所以,为了避免过拟合,我们引入剪枝技术。

除了剪枝技术我们还有一种解决方法:当前结点中的记录数低于一个最小阈值就停止分裂,采用多数表决的方法决定叶节点的分类。

2.3 如何剪枝?

两种方案:先剪枝和后剪枝

先剪枝说白了就是提前结束决策树的增长,跟上述决策树停止生长的方法一样。

后剪枝是指在决策树生长完成之后再进行剪枝的过程。

2.4. 常见减枝方法汇总

决策树剪枝的基本策略有"预剪枝" (pre-pruning)和"后剪枝"(post- pruning) 。

预剪枝是指在决策树生成过程中,对每个结点在划分前先进行估计,若当前结点的划分不能带来决策树泛化性能提升,则停止划分并将当前结点标记为叶结点;

后剪枝则是先从训练集生成一棵完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶结点。

2.5 剪枝技术对比

预剪枝优点
预剪枝使决策树的很多分支没有展开,不单降低了过拟合风险,还显著减少了决策树的训练、测试时间开销

预剪枝缺点:
有些分支的当前划分虽不能提升泛化性能,甚至会导致泛化性能降低,但在其基础上进行的后续划分却有可能导致性能的显著提高
预剪枝决策树也带来了欠拟合的风险

后剪枝优点:
比预剪枝保留了更多的分支。一般情况下,后剪枝决策树的欠拟合风险很小,泛化性能往往优于预剪枝

后剪枝缺点:
但后剪枝过程是在生成完全决策树之后进行的,并且要自底向上地对树中所有非叶子节点进行逐一考察,因此在训练时间开销比未剪枝的决策树和预剪枝的决策树都要大得多。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值