CART(Classification And Regression Tree)学习笔记(主要是为了记录剪枝算法中的 g ( t ) g(t) g(t) 如何理解)
在看《统计学习方法》5.5.2 里有关 CART 剪枝的章节,里面有一个在
α
∈
[
α
i
,
α
i
+
1
)
\alpha \in[\alpha_i, \alpha_{i+1})
α∈[αi,αi+1) 区间中求最优子树的算法。第一遍看的时候,还真是有点费解。看懂之后,就感觉很清楚了,在此记录一下,后面回忆的时候再看。
CART 算法
为了方便理解,先简单介绍一下分类于回归树(classification and regression tree, CART)。
CART 树是决策树的一种实现,可以用于分类和回归。(决策树本身就是可以用来做分类和回归。)
CART 假设决策树是二叉树。
CART 算法分为两步:
- 决策树生成:基于训练数据生成决策树,生成的决策树尽量大(方便后面剪枝之后,有一个足够大的备选集)。
- 决策树剪枝:使用损失函数对生成的树进行剪枝。
对于分类和回归两种树,最不好理解的就是如何进行预测和计算损失,下面分别简单将一下。
回归树
预测:简单来说,就是把叶结点的所有训练样本的标签的均值,作为预测结果。
具体的,当要预测一个测试样本
x
i
x_i
xi(第
i
i
i个样本)时,从根节点开始,每个节点会有一个分类器
(
j
,
s
)
(j, s)
(j,s),其中
j
j
j 表示当前节点会对样本中的第
j
j
j个特征的值进行判断,如果
x
i
j
≤
s
x_i^j \leq s
xij≤s,则该样本进入左侧分支,否则进入右侧分支,直到到达一个叶结点。回归树给每个叶结点都定义了一个值
c
m
c_m
cm,作为落入当前叶结点的样本的预测值。即,
f
(
x
i
)
=
c
m
,
x
i
∈
R
m
f(x_i) = c_m, x_i \in R_m
f(xi)=cm,xi∈Rm
也可以表示为
f
(
x
i
)
=
∑
c
m
∗
I
(
x
i
∈
R
m
)
其
中
I
(
x
i
∈
R
m
)
=
{
1
,
x
i
∈
R
m
0
,
x
i
∉
R
m
f(x_i) = \sum{c_m * I(x_i \in R_m)} \\ 其中\\ I(x_i \in R_m) = \{_{1, x_i \in R_m}^{0, x_i \not\in R_m}
f(xi)=∑cm∗I(xi∈Rm)其中I(xi∈Rm)={1,xi∈Rm0,xi∈Rm
下面我们来说一下,为什么
c
m
=
a
v
g
(
y
i
∣
y
i
∈
R
m
)
c_m = avg(y_i|y_i \in R_m)
cm=avg(yi∣yi∈Rm)。其中,假设一个树共有
M
M
M 个叶结点,
R
m
R_m
Rm 表示样本落入当前叶结点的样本的集合。
根据回归树的定义:一个回归树对应着输入空间的一个划分及在划分单元的输出值。那么如何选择这个值,才能使回归树的损失最小呢?
让我们假设,一个回归树的形状已经确定了,这时候我们用平方误差来表示训练数据的预测误差
∑
(
y
i
−
f
(
x
i
)
)
2
\sum{(y_i - f(x_i))^2}
∑(yi−f(xi))2。因为对于每个叶结点,里面的样本都是确定的,那么每个叶结点的最小值,也能让整个树的 loss 最小,即求
∑
y
i
∈
R
m
(
y
i
−
c
m
)
2
\sum_{y_i \in R_m}{(y_i - c_m)^2}
∑yi∈Rm(yi−cm)2 的最小值。根据二元一次方程的极值定理,我们可以算
F
(
y
i
)
=
∑
y
i
∈
R
m
(
y
i
−
c
m
)
2
F(y_i) = \sum_{y_i \in R_m}{(y_i - c_m)^2}
F(yi)=∑yi∈Rm(yi−cm)2 的导数
F
(
y
i
)
′
=
∑
y
i
∈
R
m
2
(
y
i
−
c
m
)
=
0
F(y_i)^{'} = \sum_{y_i \in R_m}{2(y_i - c_m)} = 0
F(yi)′=∑yi∈Rm2(yi−cm)=0,可以得出
c
m
=
a
v
g
(
y
i
∣
y
i
∈
R
m
)
c_m = avg(y_i|y_i \in R_m)
cm=avg(yi∣yi∈Rm)。
计算损失:通过上面的公式,我们可以直到,当一个树确定的时候,它每个叶结点的预测值,以及整体的损失都是可以计算出来的。那么,当我们在选在子树的时候,可以把每个特征的每一个出现的值作为划分,计算出当前的损失。选取一个特征的一个具体值 ( j , s ) (j, s) (j,s),使得通过这个分类器划分之后的树,损失最小。这样,就完成了一次分类器的学习。
分类树
预测:落入叶结点的训练样本中,类别个数最多的类,作为预测标签。
损失函数:
回归树的损失表示的是预测结果和样本标签的分离程度,那么分类树应该用什么作为损失函数呢?
答案是基尼指数:
G
i
n
i
=
∑
k
=
1
K
p
k
∗
(
1
−
p
k
)
=
1
−
∑
k
=
1
K
p
k
2
Gini = \sum_{k = 1}^{K}p_k * (1 - p_k) = 1 - \sum_{k = 1}^{K}p_k^2
Gini=k=1∑Kpk∗(1−pk)=1−k=1∑Kpk2
如果落入一个叶结点的样本只有一个类别,那么基尼指数就是 0。
如果落入一个叶结点的类别比较分散,那么当前叶结点的基尼系数也就比较大。
CART 剪枝
树的复杂程度对损失的影响
对两类损失函数有了一定的了解之后,就可以比较好的理解剪枝的过程。
其实,不知道损失函数的具体实现,也不影响对剪枝的理解。
我们要对一个树进行剪枝,那么就需要把树的复杂度加入到我们的损失函数里。
C
α
(
T
)
=
C
(
T
)
+
α
(
T
)
C_\alpha(T) = C(T) + \alpha(T)
Cα(T)=C(T)+α(T)
其中
C
(
T
)
C(T)
C(T),可以是回归中的平方差,或者是分类中的基尼指数。这取决于你要建立一个什么类型的树。
T
T
T 表示任意一个结点,以及从该结点作为根节点的子树。
剪枝前后的损失
下面,我们来看对任意一个结点
t
t
t 的剪枝前后的两个树:
t
t
t: 对结点
t
t
t剪枝之后,只保留结点
t
t
t ,孩子结点都不要了。所有的叶结点的样本都合并到
t
t
t 结点。
T
t
T_t
Tt: 剪枝之前,以
t
t
t 为根结点的完整 CART 决策树。
我们来看一下剪枝前后的损失:
C
α
(
t
)
=
C
(
t
)
+
α
∗
1
,
因
为
此
时
只
有
一
个
结
点
C
α
(
T
t
)
=
C
(
T
t
)
+
α
∗
∣
T
t
∣
C_\alpha(t) = C(t) + \alpha * 1, 因为此时只有一个结点 \\ C_\alpha(T_t) = C(T_t) + \alpha * |T_t|
Cα(t)=C(t)+α∗1,因为此时只有一个结点Cα(Tt)=C(Tt)+α∗∣Tt∣
当
α
\alpha
α 足够小的时候,比如说 0,那么
C
α
(
T
t
)
<
C
α
(
t
)
C_\alpha(T_t)<C_\alpha(t)
Cα(Tt)<Cα(t),因为树越负责,损失也就越小,这个很好理解。
但是当
α
\alpha
α 越来越大的时候,一定有
α
\alpha
α 满足
C
α
(
T
t
)
=
C
α
(
t
)
C_\alpha(T_t)=C_\alpha(t)
Cα(Tt)=Cα(t),此时
α
=
C
(
t
)
−
C
(
T
t
)
∣
T
t
∣
−
1
\alpha=\frac{C(t) - C(T_t)}{|T_t| - 1}
α=∣Tt∣−1C(t)−C(Tt)。
关键点来了,这时候,每个结点都可以算出一个值 C ( t ) − C ( T t ) ∣ T t ∣ − 1 \frac{C(t) - C(T_t)}{|T_t| - 1} ∣Tt∣−1C(t)−C(Tt),我们用 g ( t ) g(t) g(t)来表示这个值。
如果
α
>
g
(
t
)
\alpha>g(t)
α>g(t)那么,
C
α
(
T
t
)
−
C
α
(
t
)
=
C
(
T
t
)
+
α
∗
∣
T
t
∣
−
(
C
(
t
)
+
α
∗
1
)
=
C
(
T
t
)
−
C
(
t
)
+
α
∗
(
∣
T
t
∣
−
1
)
>
C
(
T
t
)
−
C
(
t
)
+
C
(
t
)
−
C
(
T
t
)
∣
T
t
∣
−
1
∗
(
∣
T
t
∣
−
1
)
=
0
C_\alpha(T_t) - C_\alpha(t) = C(T_t) + \alpha * |T_t| - (C(t) + \alpha * 1) \\= C(T_t) - C(t) + \alpha * (|T_t| - 1) \\> C(T_t) - C(t) + \frac{C(t) - C(T_t)}{|T_t| - 1} * (|T_t| - 1) = 0
Cα(Tt)−Cα(t)=C(Tt)+α∗∣Tt∣−(C(t)+α∗1)=C(Tt)−C(t)+α∗(∣Tt∣−1)>C(Tt)−C(t)+∣Tt∣−1C(t)−C(Tt)∗(∣Tt∣−1)=0
即
C
α
(
T
t
)
>
C
α
(
t
)
C_\alpha(T_t)>C_\alpha(t)
Cα(Tt)>Cα(t)。此时的含义是任何结点
t
t
t的
g
(
t
)
<
α
g(t) < \alpha
g(t)<α,就应该对这个结点进行剪枝,否则整体的损失就会包含
C
α
(
T
t
)
C_\alpha(T_t)
Cα(Tt),而使整个树的损失变大,大于剪枝后的
C
α
(
t
)
C_\alpha(t)
Cα(t)。
剪枝流程
剪枝的过程,首先计算出每个结点的 g ( t ) g(t) g(t),按照升序排列, α \alpha α 从零开始,依次取得不同的 g ( t ) g(t) g(t),把 g ( t ) < α g(t)<\alpha g(t)<α的结点都减掉,就得到一颗颗子树 T 0 , T 1 , . . . , T n {T_0, T_1, ..., T_n} T0,T1,...,Tn
参考
《统计学习方法》李航
cart树怎么进行剪枝? - Zergzzlun的回答 - 知乎