《统计学习方法》中CATRT剪枝算法,看一次糊涂一次,在此总结一下。
主要思路:对于原始树T0。第一次剪枝后,得到子树T1。然后从T1中剪枝得到T2,直到只剩一个根结点的子树Tn。于是得到了T0,T1,...,TN一共n+1棵子树。然后再用这n+1棵子树预测独立的验证数据集,谁的误差最小就选谁。
那么问题来了:怎么剪枝,剪哪些节点?
g(t)=(C(t)-C(Tt))/(|Tt|-1)表示简直后整体损失函数减少的程度(联立书中(5.27)和(5.28)可得)。
a>g(t)时,C(t)<C(Tt),则剪枝。a<g(t)时,C(t)>C(Tt)不剪枝。a=g(t)表示剪枝的阈值。
所以每一个节点就对应一个阈值g(t)。
当a从0开始缓慢增大,超过了某个结点的g(t),但还没有超过其他结点的g(t)时会有某棵子树该剪,其他子树不该剪的情况。
这样随着a不断增大,不断地剪枝,就得到了n+1棵子树T0,T1,...,TN。
然后只要用独立数据集测试这n+1棵子树,哪棵子树的误差最小就用哪棵树。