一棵完全生长的决策树会面临一个很严重的问题,即过拟合。当模型过拟合进行预测时,在测试集上的效果将会很差。因此我们需要对决策树进行剪枝, 剪掉一些枝叶,提升模型的泛化能力。
决策树的剪枝通常有两种方法,预剪枝( Pre-Pruning )和后剪枝( Post-Pruning )。
文章目录
一、预剪枝
1.什么是预剪枝
预剪枝 , 即在生成决策树的过程中提前停止树的增长。核心思想是在树中结点进行扩展之前,先计算当前的划分是否能带来模型泛化能力的提升,如果不能,则不再继续生长子树。此时可能存在不同类别的样本同时存于结点中,按照多数投票的原则判断该结点所属类别。预剪枝对于何时停止决策树的生长有以下几种方法:
( 1 )当树到达一定深度的时候,停止树的生长。
( 2 )当到达当前结点的样本数量小于某个阈值的时候,停止树的生长。
( 3 )计算每次分裂对测试集的准确度提升,当小于某个阈值的时候 ,不再继续扩展。
2.预剪枝的优缺点
预剪枝具有思想直接、算法简单、效率高等特点,适合解决大规模问题。 但如何准确地估计何时停止树的生长(即上述方法中的深度或阈值),针对不同问题会有很大差别,需要一定经验判断。且预剪枝存在一定局限性,高欠拟合的风险,虽然当前的划分会导致测试集准确率降低 , 但在之后的划分中,准确率可能会高显著上升。
二、后剪枝
1.什么是后剪枝
后剪枝,是在已经生成的过拟合决策树上进行剪枝,得到简化版的剪枝决策树。核心思想是让算法生成一棵完全生长的决策树,然后从最底层向上计算是否剪枝。剪枝过程将子树删除,用一个叶子结点替代,该结点的类别同样按照多数投票的原则进行判断。 同样地 ,后剪枝也可以通过在测试集上的准确率进行判断,如果剪枝过后准确率有所提升,则进行剪枝。 相比于预剪枝,后剪枝方法通常可以得到泛化能力更强的决策树,但时间开销会更大。
常见的后剪枝方法包括错误率降低剪枝( Reduced Error Pruning, REP )、悲观剪枝( Pessimistic Error Pruning, PEP )、代价复杂度剪枝( Cost Complexity Pruning, CCP )、最小误差剪枝( Minimum Error Pruning, MEP )、 CVP ( Critical Value Pruning )、 OPP ( Optimal Pruning )等方法,这些剪枝方法各有利弊,关注不同的优化角度。
2.后剪枝的优缺点
-后剪枝比预剪枝保留了更多的分支,欠拟合风险小,泛化性能往往优于预剪枝决策树
-训练时间开销大:后剪枝过程是在生成完全决策树之后进行的,需要自底向上对所有非叶结点逐- -计算
三、代码
1、导入数据
import math
import numpy as np
# 创建西瓜书数据集2.0
def createDataXG20():
data = np.array([['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑']
, ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑']
, ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑']
, ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑']
, ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑']
, ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘']
, ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘']
, ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑']
, ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑']
, ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘']
, ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑']
, ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘']
, ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑']
, ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑']
, ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘']
, ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑']
, ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑']])
label = np.array(['是', '是', '是', '是', '是', '是', '是', '是', '否', '否', '否', '否', '否', '否', '否', '否', '否'])
name = np.array(['色泽', '根蒂', '敲声', '纹理', '脐部', '触感'])
return data, label, name
def splitXgData20(xgData, xgLabel):
xgDataTrain = xgData[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16],:]
xgDataTest = xgData[[3, 4, 7, 8, 10, 11, 12],:]
xgLabelTrain = xgLabel[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16]]
xgLabelTest = xgLabel[[3, 4, 7, 8, 10, 11, 12]]
return xgDataTrain, xgLabelTrain, xgDataTest, xgLabelTest