CART构建与剪枝

本文介绍了CART算法,它是ID3算法的改进版,采用二元切分处理连续数据,解决了ID3的局限。文章通过实例展示了CART算法构建决策树的过程,并讨论了其过拟合问题,提出剪枝方法来提高泛化能力。同时,提供了不同数据集上的CART决策树构建示例,以及剪枝前后的对比。
摘要由CSDN通过智能技术生成

上周实现了离散变量的决策树的构建(ID3算法),它的做法是每次选取当前最佳的特征来分割数据,并按照该特征所有的可能值来切分。也就是说,如果一个特征有4种取值,那么数据被切分成4份,一旦按某特征切分后,便固定死了,该特征在之后的算法执行过程中将不会再起作用,显然,这种切分方式过于迅速。而此外,ID3算法不能直接处理连续型特征。
再补充一下用ID3算法生成决策树的图例。
我们的例子是李航的《统计学习方法》第五章的表5.1,根据该表生成决策树,在已知年龄、有工作、有自己房子、信贷情况的情况下判断是否给贷款.
图1 贷款申请样本数据表
用ID3算法生成的决策树如下(画图的程序实现在最后,参照的是Peter Harrington的《机器学习实战》):
图2 ID3算法生成的贷款决策树
效果很明显,从杂乱无章的15条记录中提取出这么精辟的决策树,有了这棵决策树便很轻易的可以判断该不该给某人贷款,如果他有房子,就给贷,如果没有,但他有工作,也给贷,如果都没有,就不给贷。比表5.1精简有效多了。
再来看一个例子,周志华的《机器学习》的判断是否为好瓜的数据:
图3 判断是否为好瓜
判断一个西瓜可以从色泽,根蒂,敲声,纹理,脐部,触感6个特征去判断,每个特征都有2-3个值,用ID3算法生成的决策树如下:
图4 ID3算法生成是否为好瓜的决策树
这里一个节点可以有2个以上的分支,取决于每个特征的所有可能值。这样也使一团杂乱无章的数据有了个很清晰的决策树。

**总结:
ID3算法可以使离散的问题清晰简单化,但也有两点局限:
1. 切分过于迅速
2. 不能直接处理连续型特征**
如遇到连续变化的特征或者特征可能值很多的情况下,算法得出的效果并不理想而且没有多大用处。大多数情况下,生成决策树的目的是用来分类的。

这周,生成决策树的算法是CART算法,不像ID3算法,它是一种二元切分法,具体处理方法:如果特征值大于给定值就走左子树,否则就走右子树。解决了ID3算法的局限,但同时,如果用来分类,生成的决策树容易太贪心,满足了大部分训练数据,出现过拟合。为提高泛化能力,需对其
进行剪枝,把某些节点塌陷成一类。
在本文,构建CART的实现算法有两种(程序在最后)
一种是Peter Harrington的《机器学习实战》的对连续数据的构建算法,核心方法(选取最优特征)的伪代码如下:
*遍历每个特征:
遍历每个特征值:
将数据切分成两份
计算切分的误差
如果当前误差小于当前最小误差:
更新当前最小误差
更新当前最优特征和最优切分点
返回最优切分特征和最优切分点*
一种是李航的《统计学习方法》的用基尼指数构建的算法,程序是自己实现的,目前只能针对离散性数据,核心方法的伪代码如下:
*遍历每个特征:
遍历每个特征值:
将数据切分成两份
计算切分的基尼指数
如果基尼指数小于当前基尼指数:
更新当前基尼指数
更新当前最优特征和最优切分点
返回最优切分特征和最优切分点*
只是把误差计算方式变成了基尼指数,其他基本类似。

对前面两例用CART算法生成的决策树如下:
图5 CART算法生成的贷款决策树
图6 CART算法生成的是否好瓜决策树

图5和图2是一样的,因为用来切分的特征都只有两类
但图6和图4便不一样。

再来对连续的数据构建决策树,数据来自于Peter Harrington的《机器学习实战》的第九章ex0.txt
图7 ex0.txt
肉眼可以分辨,整段数据可分为5段,用CART算法生成的结果如下:

{‘spInd’: 0, ‘spVal’: 0.39434999999999998, ‘left’: {‘spInd’: 0, ‘spVal’: 0.58200200000000002, ‘left’: {‘spInd’: 0, ‘spVal’: 0.79758300000000004, ‘left’: 3.9871631999999999, ‘right’: 2.9836209534883724}, ‘right’: 1.980035071428571}, ‘right’: {‘spInd’: 0, ‘spVal’: 0.19783400000000001, ‘left’: 1.0289583666666666, ‘right’: -0.023838155555555553}}

(实在不想画图了,就用dict表示吧,spInd表示当前分割特征,spVal表示当前分割值,left表示坐子节点,right表示右子节点)
从dict中也明显可以看到,它将数据分成5段,但这个前提是ops=(1,4)选的好,对树进行预剪枝了。

如果ops=(0.1,0.4)会发生什么呢?

{‘spInd’: 0, ‘spVal’: 0.39434999999999998, ‘left’: {‘spInd’: 0, ‘spVal’: 0.58200200000000002, ‘left’: {‘spInd’: 0, ‘spVal’: 0.79758300000000004, ‘left’: {‘spInd’: 0, ‘spVal’: 0.81900600000000001, ‘left’: {‘spInd’: 0, ‘spVal’: 0.83269300000000002, ‘left’: 3.9814298333333347, ‘right’: {‘spInd’: 0, ‘spVal’: 0.81913599999999998, ‘left’: 4.5692899999999996, ‘right’: 4.048082}}, ‘right’: 3.7688410000000001}, ‘right’: {‘spInd’: 0, ‘spVal’: 0.62039299999999997, ‘left’: {‘spInd’: 0, ‘spVal’: 0.62261599999999995, ‘left’: 2.9787170277777779, ‘right’: 2.6702779999999997}, ‘right’: {‘spInd’: 0, ‘spVal’: 0.61605100000000002, ‘left’: 3.5225040000000001, ‘right’: 3.0497069999999997}}}, ‘right’: {‘spInd’: 0, ‘spVal’: 0.48669800000000002, ‘left’: {‘spInd’: 0, ‘spVal’: 0.53324099999999997, ‘left’: {‘spInd’: 0, ‘spVal’: 0.55900899999999998, ‘left’: 2.0720909999999999, ‘right’: 1.8145387500000001}, ‘right’: 2.0843065555555551}, ‘right’: 1.8810897500000001}}, ‘right’: {‘spInd’: 0, ‘spVal’: 0.19783400000000001, ‘left’: {‘spInd’: 0, ‘spVal’: 0.21054200000000001, ‘left’: {‘spInd’: 0, ‘spVal’: 0.37526999999999999, ‘left’: 1.2040690000000001, ‘right’: {‘spInd’: 0, ‘spVal’: 0.316465, ‘left’: 0.86561450000000006, ‘right’: {‘spInd’: 0, ‘spVal’: 0.23417499999999999, ‘left’: 1.1113766363636364, ‘right’: 0.90613224999999997}}}, ‘right’: 1.3753635000000002}, ‘right’: {‘spInd’: 0, ‘spVal’: 0.14865400000000001, ‘left’: 0.071894545454545447, ‘right’: {‘spInd’: 0, ‘spVal’: 0.14314299999999999, ‘left’: -0.27792149999999999, ‘right’: -0.040866062499999994}}}}

显然,过拟合了。生成了很多不必要的节点。在实际应用中,根本不能控制数据值得大小,所以ops也很难选好,而ops的选择对结果的影响很大。所以仅仅预剪枝是远远不够的。

于是需要后剪枝。简单来说,就是选择ops,使得构建出的树足够大,接下来从上而下找到叶节点,用测试集的数据来判断这些叶节点是否能降低测试误差,如果能,就合并,伪代码如下:
*基于已有的树切分测试数据:
如果存在任一子集是一棵树,则在该子集递归剪枝过程
计算当前两个叶节点合并后的误差
计算合并前的误差
如果合并后的误差小于合并前的误差:
将两个叶节点合并*

对上述决策树进行剪枝,由于没有测试数据,便拿前150当作训练数据,后50当作测试数据,图如下:
图8 ex0.txt训练数据和测试数据

同样,ops=(0.1,0.4),剪枝后的树为:
{‘spInd’: 0, ‘spVal’: 0.39434999999999998, ‘left’: {‘spInd’: 0, ‘spVal’: 0.58028299999999999, ‘left’: {‘spInd’: 0, ‘spVal’: 0.79758300000000004, ‘left’: 3.9739993000000005, ‘right’: 3.0065657575757574}, ‘right’: 1.9667640539772728}, ‘right’: {‘spInd’: 0, ‘spVal’: 0.19783400000000001, ‘left’: 1.0753531944444445, ‘right’: -0.028014558823529413}}

由那么复杂的树剪枝剪成只有五个类别。效果不错

实现代码如下:

treePlotter.py

'''
Created on 2017年7月30日

@author: fujianfei
'''

import matplotlib.pyplot as plt


plt.rcParams['font.sans-serif']=['SimHei']#解约matplotlib画图,中文乱码问题

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):<
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值