1.CART算法与决策树算法
CART算法由Breiman等人在 1984 年提出的一种二叉树模型。CART算法仅是二叉树模型。CART算法不仅支持分类树,也支持回归树。
CART回归树有着比较广泛的应用。
2.平方误差的计算
本文介绍CART算法中的回归树。
CART回归树通过计算平方误差来选择分裂特征和确定分裂节点。
对于特征的某个划分点
,将特征
划分为两个部分:
则其平方误差为:
其中,分别是两个划分区域的预测值。CART算法在计算过程中会得出相应的预测值。
3.CART回归算法过程
CART回归算法在输出回归树的过程中会解决以下3个问题:特征的选择、特征分裂点的选择、分裂区域的预测。
3.1 输入输出
1)输入:训练集D。
2)输出:回归树。
3.2 算法过程
3.2.1 节点预判断
满足以下条件之一的该节点分裂停止:
1)该节点的输出特征的取值都相等;
2)该节点的输入特征的取值均相同;
3)该节点的平方误差小于指定值;
4)其它设定的停止条件(如树的深度、节点的样本数量等)。
否则进入下一步。
3.2.2 选择分裂特征和分裂节点
计算每一个特征所有划分条件下的平方误差,选择使其达到最小值的特征和对应划分
:
对于每一确定的,要使得
最小,不难得出:
,
其中,表示对应划分
样本的数量。
如此,可以找到最优的
3.2.3 节点输出值
对应的最优分裂特征和分裂点的得到的划分,分别以前述
为对应区域的输出值。
3.2.4 继续分裂
继续分裂,直到每个节点都满足前述分裂停止的条件。
3.2.5 生成决策树
通过上述划分,最终将输入空间划分成个区域
,得到回归树:
至此,决策树生成,算法完毕。
4.CART算法演示
4.1 数据集
age | level | income | debt | score | |
1 | 18 | 1 | 3000 | 1000 | 50 |
2 | 20 | 1 | 4000 | 1200 | 66 |
3 | 22 | 3 | 3600 | 2600 | 60 |
4 | 25 | 2 | 2800 | 500 | 70 |
5 | 28 | 3 | 4800 | 0 | 80 |
6 | 30 | 5 | 4200 | 2000 | 76 |
7 | 35 | 4 | 6000 | 1800 | 88 |
8 | 40 | 2 | 3500 | 3000 | 56 |
9 | 45 | 6 | 8000 | 4500 | 82 |
10 | 50 | 1 | 3900 | 1500 | 64 |
以age,level,income,debt回归score。
4.2 回归过程
计算每一个特征所有切分点的平方误差。
4.2.1 数据排序
在计算具体某一特征的平方误差时,应将该特征按从小到大排序,对应的输出特征的顺序随之变动。如在计算income的平方误差时,将income按从小到大排序,数据展示如下:
income | score | |
4 | 2800 | 70 |
1 | 3000 | 50 |
8 | 3500 | 56 |
3 | 3600 | 60 |
10 | 3900 | 64 |
2 | 4000 | 66 |
6 | 4200 | 76 |
5 | 4800 | 80 |
7 | 6000 | 88 |
9 | 8000 | 82 |
4.2.2 平方误差计算
将特征的值排序后,将其切分成两份,计算每一次切分的平方误差。其中切分点的取值为临近两特征的平均值。
如:income的第1个切分点为(2800+3000)/2=2900,其左端平均数为70,右端平均数为
(50+56+60+64+66+76+80+88+82)/9=69.1111,其左右平方误差之和为
。类似计算其它切分点的平方误差。
按上述计算每一个特征每一个可能切分点的平方误差。
以上计算结果展示如下:
age | level | income | debt | |
1 | 936 | 936 | 1344.888889 | 1216 |
2 | 1032 | 1032 | 1134 | 1261.5 |
3 | 870.095238 | 982.857143 | 870.095238 | 1318.095238 |
4 | 950.333333 | 1046.333333 | 652 | 1297 |
5 | 1185.6 | 705.6 | 499.2 | 1243.2 |
6 | 1273 | 337 | 337 | 1342.333333 |
7 | 1330.666667 | 643.428571 | 489.52381 | 1301.714286 |
8 | 1309.5 | 1105.5 | 721.5 | 1345.5 |
9 | 1315.555556 | 1163.555556 | 1163.555556 | 1163.555556 |
可见,其中income和level的第6个切分点的平方误差均最小,为337,不妨选择income作为第一个分裂特征,其分裂点为从小到大排序后的4000、4200之间,取平均值4100。
然后对上述两个分支,再按前述方法分裂。直至满足分裂停止条件。
5.代码实现与可视化
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree import plot_tree
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(
{'age':[18,20,22,25,28,30,35,40,45,50],
'level':[1,1,3,2,3,5,4,2,6,1],
'income':[3000,4000,3600,2800,4800,4200,6000,3500,8000,3900],
'debt':[1000,1200,2600,500,0,2000,1800,3000,4500,1500],
'score':[50,66,60,70,80,76,88,56,82,64]}
)
#决策树模型
X = df.iloc[:,0:4]
y = df.iloc[:,4]
cart_tree = DecisionTreeRegressor(random_state=0)
cart_tree.fit(X,y)
# 可视化
plt.figure(figsize=(15,9))
plot_tree(cart_tree,filled=True,feature_names=df.columns[0:-1])
plt.show()
6.CART回归树总结
CART回归算法有着非常广泛的应用。
1)CART回归算法也是二叉树。这与ID3算法、C4.5算法均不同,这两个算法均是多叉树。
2)CART回归算法中特征分裂的主要依据是平方误差。ID3算法、C4.5算法分别是信息增益、信息增益比。
3)与ID3、C4.5算法相比,CART回归算法没有对数运算,但却需要对数据进行排序。
4)CART分类算法支持剪枝,采用的是后剪枝。