如何理解 CART 剪枝 里的 g(t)

CART(Classification And Regression Tree)学习笔记(主要是为了记录剪枝算法中的 g ( t ) g(t) g(t) 如何理解)

在看《统计学习方法》5.5.2 里有关 CART 剪枝的章节,里面有一个在 α ∈ [ α i , α i + 1 ) \alpha \in[\alpha_i, \alpha_{i+1}) α[αi,αi+1) 区间中求最优子树的算法。第一遍看的时候,还真是有点费解。看懂之后,就感觉很清楚了,在此记录一下,后面回忆的时候再看。

CART 算法

为了方便理解,先简单介绍一下分类于回归树(classification and regression tree, CART)。
CART 树是决策树的一种实现,可以用于分类和回归。(决策树本身就是可以用来做分类和回归。)
CART 假设决策树是二叉树。
CART 算法分为两步:

  1. 决策树生成:基于训练数据生成决策树,生成的决策树尽量大(方便后面剪枝之后,有一个足够大的备选集)。
  2. 决策树剪枝:使用损失函数对生成的树进行剪枝。

对于分类和回归两种树,最不好理解的就是如何进行预测计算损失,下面分别简单将一下。

回归树

预测:简单来说,就是把叶结点的所有训练样本的标签的均值,作为预测结果。
具体的,当要预测一个测试样本 x i x_i xi(第 i i i个样本)时,从根节点开始,每个节点会有一个分类器 ( j , s ) (j, s) (j,s),其中 j j j 表示当前节点会对样本中的第 j j j个特征的值进行判断,如果 x i j ≤ s x_i^j \leq s xijs,则该样本进入左侧分支,否则进入右侧分支,直到到达一个叶结点。回归树给每个叶结点都定义了一个值 c m c_m cm,作为落入当前叶结点的样本的预测值。即,
f ( x i ) = c m , x i ∈ R m f(x_i) = c_m, x_i \in R_m f(xi)=cm,xiRm
也可以表示为
f ( x i ) = ∑ c m ∗ I ( x i ∈ R m ) 其 中 I ( x i ∈ R m ) = { 1 , x i ∈ R m 0 , x i ∉ R m f(x_i) = \sum{c_m * I(x_i \in R_m)} \\ 其中\\ I(x_i \in R_m) = \{_{1, x_i \in R_m}^{0, x_i \not\in R_m} f(xi)=cmI(xiRm)I(xiRm)={1,xiRm0,xiRm
下面我们来说一下,为什么 c m = a v g ( y i ∣ y i ∈ R m ) c_m = avg(y_i|y_i \in R_m) cm=avg(yiyiRm)。其中,假设一个树共有 M M M 个叶结点, R m R_m Rm 表示样本落入当前叶结点的样本的集合。
根据回归树的定义:一个回归树对应着输入空间的一个划分及在划分单元的输出值。那么如何选择这个值,才能使回归树的损失最小呢?
让我们假设,一个回归树的形状已经确定了,这时候我们用平方误差来表示训练数据的预测误差 ∑ ( y i − f ( x i ) ) 2 \sum{(y_i - f(x_i))^2} (yif(xi))2。因为对于每个叶结点,里面的样本都是确定的,那么每个叶结点的最小值,也能让整个树的 loss 最小,即求 ∑ y i ∈ R m ( y i − c m ) 2 \sum_{y_i \in R_m}{(y_i - c_m)^2} yiRm(yicm)2 的最小值。根据二元一次方程的极值定理,我们可以算 F ( y i ) = ∑ y i ∈ R m ( y i − c m ) 2 F(y_i) = \sum_{y_i \in R_m}{(y_i - c_m)^2} F(yi)=yiRm(yicm)2 的导数 F ( y i ) ′ = ∑ y i ∈ R m 2 ( y i − c m ) = 0 F(y_i)^{'} = \sum_{y_i \in R_m}{2(y_i - c_m)} = 0 F(yi)=yiRm2(yicm)=0,可以得出 c m = a v g ( y i ∣ y i ∈ R m ) c_m = avg(y_i|y_i \in R_m) cm=avg(yiyiRm)

计算损失:通过上面的公式,我们可以直到,当一个树确定的时候,它每个叶结点的预测值,以及整体的损失都是可以计算出来的。那么,当我们在选在子树的时候,可以把每个特征的每一个出现的值作为划分,计算出当前的损失。选取一个特征的一个具体值 ( j , s ) (j, s) (j,s),使得通过这个分类器划分之后的树,损失最小。这样,就完成了一次分类器的学习。

分类树

预测:落入叶结点的训练样本中,类别个数最多的类,作为预测标签。
损失函数
回归树的损失表示的是预测结果和样本标签的分离程度,那么分类树应该用什么作为损失函数呢?
答案是基尼指数:
G i n i = ∑ k = 1 K p k ∗ ( 1 − p k ) = 1 − ∑ k = 1 K p k 2 Gini = \sum_{k = 1}^{K}p_k * (1 - p_k) = 1 - \sum_{k = 1}^{K}p_k^2 Gini=k=1Kpk(1pk)=1k=1Kpk2
如果落入一个叶结点的样本只有一个类别,那么基尼指数就是 0。
如果落入一个叶结点的类别比较分散,那么当前叶结点的基尼系数也就比较大。

CART 剪枝

树的复杂程度对损失的影响

对两类损失函数有了一定的了解之后,就可以比较好的理解剪枝的过程。
其实,不知道损失函数的具体实现,也不影响对剪枝的理解。

我们要对一个树进行剪枝,那么就需要把树的复杂度加入到我们的损失函数里。
C α ( T ) = C ( T ) + α ( T ) C_\alpha(T) = C(T) + \alpha(T) Cα(T)=C(T)+α(T)
其中 C ( T ) C(T) C(T),可以是回归中的平方差,或者是分类中的基尼指数。这取决于你要建立一个什么类型的树。
T T T 表示任意一个结点,以及从该结点作为根节点的子树。

剪枝前后的损失

下面,我们来看对任意一个结点 t t t 的剪枝前后的两个树:
t t t: 对结点 t t t剪枝之后,只保留结点 t t t ,孩子结点都不要了。所有的叶结点的样本都合并到 t t t 结点。
T t T_t Tt: 剪枝之前,以 t t t 为根结点的完整 CART 决策树。

我们来看一下剪枝前后的损失:
C α ( t ) = C ( t ) + α ∗ 1 , 因 为 此 时 只 有 一 个 结 点 C α ( T t ) = C ( T t ) + α ∗ ∣ T t ∣ C_\alpha(t) = C(t) + \alpha * 1, 因为此时只有一个结点 \\ C_\alpha(T_t) = C(T_t) + \alpha * |T_t| Cα(t)=C(t)+α1,Cα(Tt)=C(Tt)+αTt
α \alpha α 足够小的时候,比如说 0,那么 C α ( T t ) < C α ( t ) C_\alpha(T_t)<C_\alpha(t) Cα(Tt)<Cα(t),因为树越负责,损失也就越小,这个很好理解。
但是当 α \alpha α 越来越大的时候,一定有 α \alpha α 满足 C α ( T t ) = C α ( t ) C_\alpha(T_t)=C_\alpha(t) Cα(Tt)=Cα(t),此时 α = C ( t ) − C ( T t ) ∣ T t ∣ − 1 \alpha=\frac{C(t) - C(T_t)}{|T_t| - 1} α=Tt1C(t)C(Tt)

关键点来了,这时候,每个结点都可以算出一个值 C ( t ) − C ( T t ) ∣ T t ∣ − 1 \frac{C(t) - C(T_t)}{|T_t| - 1} Tt1C(t)C(Tt),我们用 g ( t ) g(t) g(t)来表示这个值。

如果 α > g ( t ) \alpha>g(t) α>g(t)那么,
C α ( T t ) − C α ( t ) = C ( T t ) + α ∗ ∣ T t ∣ − ( C ( t ) + α ∗ 1 ) = C ( T t ) − C ( t ) + α ∗ ( ∣ T t ∣ − 1 ) > C ( T t ) − C ( t ) + C ( t ) − C ( T t ) ∣ T t ∣ − 1 ∗ ( ∣ T t ∣ − 1 ) = 0 C_\alpha(T_t) - C_\alpha(t) = C(T_t) + \alpha * |T_t| - (C(t) + \alpha * 1) \\= C(T_t) - C(t) + \alpha * (|T_t| - 1) \\> C(T_t) - C(t) + \frac{C(t) - C(T_t)}{|T_t| - 1} * (|T_t| - 1) = 0 Cα(Tt)Cα(t)=C(Tt)+αTt(C(t)+α1)=C(Tt)C(t)+α(Tt1)>C(Tt)C(t)+Tt1C(t)C(Tt)(Tt1)=0
C α ( T t ) > C α ( t ) C_\alpha(T_t)>C_\alpha(t) Cα(Tt)>Cα(t)。此时的含义是任何结点 t t t g ( t ) < α g(t) < \alpha g(t)<α,就应该对这个结点进行剪枝,否则整体的损失就会包含 C α ( T t ) C_\alpha(T_t) Cα(Tt),而使整个树的损失变大,大于剪枝后的 C α ( t ) C_\alpha(t) Cα(t)

剪枝流程

剪枝的过程,首先计算出每个结点的 g ( t ) g(t) g(t),按照升序排列, α \alpha α 从零开始,依次取得不同的 g ( t ) g(t) g(t),把 g ( t ) < α g(t)<\alpha g(t)<α的结点都减掉,就得到一颗颗子树 T 0 , T 1 , . . . , T n {T_0, T_1, ..., T_n} T0,T1,...,Tn

参考

《统计学习方法》李航
cart树怎么进行剪枝? - Zergzzlun的回答 - 知乎

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值