CART算法之回归树

1.CART算法与决策树算法

CART算法由Breiman等人在 1984 年提出的一种二叉树模型。CART算法仅是二叉树模型。CART算法不仅支持分类树,也支持回归树。

CART回归树有着比较广泛的应用。

2.平方误差的计算

本文介绍CART算法中的回归树。

CART回归树通过计算平方误差来选择分裂特征和确定分裂节点。

对于特征j的某个划分点s,将特征j划分为两个部分:

R_{1}(j,s):x|x^{(j)}\leqslant s,R_{2}(j,s):x|x^{(j)}>s

则其平方误差为:

\sum_{x_{i}\in R_{1}(j,s)}^{}(y_{i}-c_{1})^{2}+\sum_{x_{i}\in R_{2}(j,s)}^{}(y_{i}-c_{2})^{2}

其中,c_{1},c_{2}分别是两个划分区域的预测值。CART算法在计算过程中会得出相应的预测值。

3.CART回归算法过程

CART回归算法在输出回归树的过程中会解决以下3个问题:特征的选择、特征分裂点的选择、分裂区域的预测。

3.1 输入输出

1)输入:训练集D。

2)输出:回归树f(x)

3.2 算法过程

3.2.1 节点预判断

满足以下条件之一的该节点分裂停止:

1)该节点的输出特征的取值都相等;

2)该节点的输入特征的取值均相同;

3)该节点的平方误差小于指定值;

4)其它设定的停止条件(如树的深度、节点的样本数量等)。

否则进入下一步。

3.2.2 选择分裂特征和分裂节点

计算每一个特征所有划分条件下的平方误差,选择使其达到最小值的特征j和对应划分s:

min_{j,s}[min_{c_{1}}\sum_{x_{i}\in R_{1}(j,s)}^{}(y_{i}-c_{1})^{2}+min_{c_{2}}\sum_{x_{i}\in R_{2}(j,s)}^{}(y_{i}-c_{2})^{2}]

对于每一确定的(j,s),要使得\sum_{x_{i}\in R_{m}(j,s)}^{}(y_{i}-c_{m})^{2}最小,不难得出:

c\widehat{}_{m}=\frac{1}{N_{m}}\sum_{x_{i}\in R_{m}(j,s)}^{}y_{i}m=1,2

其中,N_{m}表示对应划分R_{m}(j,s)样本的数量。

如此,可以找到最优的(j,s)

3.2.3 节点输出值

对应的最优分裂特征和分裂点(j,s)的得到的划分,分别以前述c\widehat{}_{m}为对应区域的输出值。

3.2.4 继续分裂

继续分裂,直到每个节点都满足前述分裂停止的条件。

3.2.5 生成决策树

通过上述划分,最终将输入空间划分成M个区域R_{1},R_{2},...,R_{M},得到回归树:

f(x)=\sum_{m=1}^{M}c\widehat{}_{m}I(x\in R_{m})

至此,决策树生成,算法完毕。

4.CART算法演示

4.1 数据集

agelevelincomedebtscore
11813000100050
22014000120066
32233600260060
4252280050070
52834800080
63054200200076
73546000180088
84023500300056
94568000450082
105013900150064

以age,level,income,debt回归score。

4.2 回归过程

计算每一个特征所有切分点的平方误差。

4.2.1 数据排序

在计算具体某一特征的平方误差时,应将该特征按从小到大排序,对应的输出特征的顺序随之变动。如在计算income的平方误差时,将income按从小到大排序,数据展示如下:

incomescore
4280070
1300050
8350056
3360060
10390064
2400066
6420076
5480080
7600088
9800082

4.2.2 平方误差计算

将特征的值排序后,将其切分成两份,计算每一次切分的平方误差。其中切分点的取值为临近两特征的平均值。

如:income的第1个切分点为(2800+3000)/2=2900,其左端平均数为ml=70,右端平均数为mr=(50+56+60+64+66+76+80+88+82)/9=69.1111,其左右平方误差之和为

(70-ml)^2+(50-mr)^2+(56-mr)^2+(60-mr)^2+(64-mr)^2+(66-mr)^2+(76-mr)^2+(80-mr)^2+(88-mr)^2+(82-mr)^2=1344.888889。类似计算其它切分点的平方误差。

按上述计算每一个特征每一个可能切分点的平方误差。

以上计算结果展示如下:

agelevelincomedebt
19369361344.8888891216
21032103211341261.5
3870.095238982.857143870.0952381318.095238
4950.3333331046.3333336521297
51185.6705.6499.21243.2
612733373371342.333333
71330.666667643.428571489.523811301.714286
81309.51105.5721.51345.5
91315.5555561163.5555561163.5555561163.555556

见,其中income和level的第6个切分点的平方误差均最小,为337,不妨选择income作为第一个分裂特征,其分裂点为从小到大排序后的4000、4200之间,取平均值4100。

然后对上述两个分支,再按前述方法分裂。直至满足分裂停止条件。

5.代码实现与可视化

from sklearn.tree import DecisionTreeRegressor
from sklearn.tree import plot_tree
import pandas as pd
import matplotlib.pyplot as plt


df = pd.DataFrame(
    {'age':[18,20,22,25,28,30,35,40,45,50],
    'level':[1,1,3,2,3,5,4,2,6,1],
    'income':[3000,4000,3600,2800,4800,4200,6000,3500,8000,3900],
    'debt':[1000,1200,2600,500,0,2000,1800,3000,4500,1500],
    'score':[50,66,60,70,80,76,88,56,82,64]}
)

#决策树模型
X = df.iloc[:,0:4]
y = df.iloc[:,4]
cart_tree = DecisionTreeRegressor(random_state=0)
cart_tree.fit(X,y)
# 可视化
plt.figure(figsize=(15,9))
plot_tree(cart_tree,filled=True,feature_names=df.columns[0:-1])
plt.show()

6.CART回归树总结

CART回归算法有着非常广泛的应用。

1)CART回归算法也是二叉树。这与ID3算法、C4.5算法均不同,这两个算法均是多叉树。

2)CART回归算法中特征分裂的主要依据是平方误差。ID3算法、C4.5算法分别是信息增益、信息增益比。

3)与ID3、C4.5算法相比,CART回归算法没有对数运算,但却需要对数据进行排序。

4)CART分类算法支持剪枝,采用的是后剪枝。
 

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值