CART决策树算法浅谈(回归树部分)

近日受小C同学的影响,开始慢慢培养写博客的习惯,“开坑之作”(之前那篇请无视~~)就打算谈谈最近研究的CART。鄙人不才,为了写这篇博客参考了不少资料,若写的有不正确的地方,还请各位大牛指正。
套话说完了,正式开始吧。CART全名为分类与回归树,意指该模型可以同时处理分类与回归问题。对于给定的训练数据集,CART通过最小化数据集的GINI系数(分类树)或者基于最小二乘准则最小化输入与输出的总均方误差(回归树)实现机器学习任务,本文首先介绍CART在回归问题中的应用。回归树的生成可分为两步—树的生成和剪枝。

1. 树的生成

对于给定的训练数据集 T={(x1,y1),(x2,y2),...(xN,yN)} ,回归树希望按照某几个特征对数据集进行递归式划分以形成二叉树,使得划分后的数据集叶子结点的输出尽可能接近训练样本的y值。这个过程主要涉及到分裂数据集的特征选择和树的递归生成。就特征选择而言,若设选择数据集T的j号特征某个分量s作为分割的阈值,将数据集分为 R1={x|xjs} , R2={x|xj>s} 两部分,则分割后的数据集与实际y值的均方误差可表示为:

min[minxiRii=1N(yif(xi))2+minxiRii=1N(yif(xi))2]

其中,f(xi)代表模型的输出值,他越接近实际值说明模型精度越高,我们考虑里面的均方误差项 M(f(xi))=minxiRii=1N(yif(xi))2 ,为求得合适的f(xi),求偏导并令其等于0,有
Mf(xi)=2i=1N(yif(xi))=0f(xi)=i=1NyiN

即在单个集合内,最优的f(xi)为集合内数据对应y值的平均值。
回到具体的CART树训练过程中,由于CART的生成是每次基于当前已经分好的数据集,求解最优的分割准则,且每次将数据分为两份,则原始的优化目标变为在每次划分数据集过程中使用如下准则寻找最优的划分
min[minc1i=1N(yic1)2+minc2i=1N(yic2)2]

c1=i=1NR1yiNR1,c2=i=1NR2yiNR2

具体来说,CART将按照每个特征,每个分量将数据集分为大于该分量和小于该分量的部分,并计算对应y值的平均,计算均方误差函数值在其中寻找值最小的那个分量,并将其作为分类准则。重复该寻找流程直到数据集空或者划分前后均方误差下降值小于一定阈值为止(初始阈值可设为inf)。
CART树具体生成步骤如下:
Step1:获得训练数据集,根据数据集第一个特征的第一个分量将数据集分为大于该分量和小于该分量的两个数据集R1和R2,其中 R1={x|Rjs},R2={x|Rj>s}
Step2:根据两个数据集对应y值的平均获得c1和c2,分别计算两个数据集的平均绝对误差(其实就是各个数据集y值的总均方差)。
Step3:重复Step1~2,遍历整个数据集的所有特征所有分量,获得均方误差值的矩阵。
Step4:根据矩阵寻找均方误差最小的分割方案,将该分割点作为树的节点,将分割后的数据集分别赋值给该节点左子树和右子树。
Step5:重复步骤1~4,递归地将数据集分割为更小的部分,直到总均方差的下降值小于某阈值或者数据集中只剩下一类数据为止。
可以看到, CART回归树的生成是一个贪心选择最优分割点的过程,这种贪心策略在一定程度上使得最开始生成的CART树具有很多缺点,这表现在树容易将噪声也拟合进去,出现过拟合,以及训练容易陷入局部最优等。为此,需要在树生成后进行一定的处理,这就是剪枝的目的。

2. 树的剪枝

如果生成的CART树枝条太多,容易把数据集中的一些噪声也拟合进去,这时候就需要减去一些枝条,防止CART树出现过拟合。剪枝又分为预剪枝和后剪枝。后剪枝需要一定数据,因此,实际使用CART树时,常常将训练数据分为训练集和剪枝的数据集。
预剪枝通过调整树停止生长的策略,如提前终止树生长(通过调整均方误差下降的最小值实现)等可实现。这种方法不需要给定数据集,但是受到建模者所给参数的影响太大,有较大弊病。
后剪枝的具体做法是,将数据根据训练好的树模型将数据集递归地分割到叶子结点,然后考虑减去叶子和不剪去叶子两种情况下数据集的均方误差值,如果剪枝使得该值变小,则剪之,否则放弃。遍历所有结点,剪去所有冗余的枝条,就实现了后剪枝。
预剪枝和后剪枝在实际实现CART算法时常常结合使用,最大可能地避免树的过拟合。

3. 基于Python的CART回归树模型实现

基于Python的CART回归树模型实现参考了《机器学习实战》一书。使用numpy和pygraphviz绘图包。pygraphviz安装并不是直接pip install就能搞定的,具体安装步骤参见http://www.cnblogs.com/AimeeKing/p/5021675.html
主要包含了树的创建函数createTree、根据特征的分量将数据集分裂的函数splitData、选择最优分裂点的函数chooseSplit、剪枝函数cutBranches、画树函数drawTree等。
splitData函数根据选定的特征的目标分量thres将数据集分为该特征数值大于thres和小于thres的两部分数据。

def splitData(data, feature, thres):
    mat0 = data[np.nonzero(data[:, feature] > thres)[0], :]
    mat1 = data[np.nonzero(data[:, feature] <= thres)[0], :]
    return mat0, mat1

chooseSplit函数遍历所有特征的所有分量,寻找最合适的数据分裂方案

def chooseSplit(data, ops=1):
    feature = data[:, 0:-1]
    # 获得特征数目以及样本数
    minFunc = []
    oriVal = float(np.var(data[:, -1]) * len(data[:, -1]))
    sampleNum, featureNum = map(int, np.shape(feature))
    if len(np.unique(np.array(feature))) == 1:  # 剩余样本都一样
        return None, np.mean(data[:, -1])  # 既然都一样,拿哪个特征无所谓
    for i in range(featureNum):
        tempFeature = feature[:, i]  # 获得切分数据集
        FeatureFunc = []
        for thres in tempFeature:
            # 拆分标签集
            subMat1, subMat2 = splitData(data, i, thres)
            # 获得标签便于计算目标函数
            y1, y2 = [subMat1[:, -1], subMat2[:, -1]]
            # 处理一下空集的情况
            if len(y2) == 0:
                FeatureFunc.append(np.var(y1) * len(y1))
            elif len(y1) == 0:
                FeatureFunc.append(np.var(y2) * len(y2))
            else:
                FeatureFunc.append(np.var(y1) * len(y1) + np.var(y2) * len(y2))
        minFunc.append(FeatureFunc)
    # 寻找最优分割特征与数值
    locFeature, locVal = np.where(minFunc == np.amin(minFunc))
    # 下降值小于ops,不再生长树
    if (oriVal - np.amin(minFunc)) < ops:
        return None, float(np.var(data[:, -1]) * len(data[:, -1]))
    spVal = float(data[locVal, locFeature[0]])  # 用于分割数据的特征
    return locFeature[0], spVal

splitData函数用于树的创建,首先选择当前数据集的最优分裂方案,如果只有一个数据点就返回,否则创建节点,并分裂数据集,将两个数据集递归地传给节点的左右子树继续分裂。最后返回生成的树。

# 创建树,生成的树叶子结点没有左右子树!!!!
def createTree(data, ops=1):
    # 选中最优的分割
    name, val = chooseSplit(data)
    if name is None:
        return val  # 直接返回
    tree = {'node': name, 'val': val}
    # 递归建树
    ldata, rdata = splitData(data, name, val)
    tree['rchild'] = createTree(rdata, ops)
    tree['lchild'] = createTree(ldata, ops)
    return tree

cutBranches用于树的后剪枝,这里需要一个判断节点是否为叶子结点的函数,由于Python是借助字典实现树结构的,所以可以判断当前结点类型是否为dict来实现。如果左枝条或者右枝条不为叶子结点,则按照给定的树模型分割数据集,递归剪枝过程。当左右节点为叶子结点时,就先按照叶子结点的要求分割一次数据集计算y的均方误差,再计算不剪枝时的均方误差,比较判断是否有必要剪枝。

# 后剪枝
def cutBranches(tree, testData=[]):
    if len(testData) == 0:  # 没有测试数据
        print "没有测试数据,不能剪枝!"
    # 左只或右枝不为叶子,则进行数据分割
    if not (isLeaf(tree['lchild'])) or not (isLeaf(tree['rchild'])):
        ldata, rdata = splitData(testData, tree['node'], tree['val'])
    # 左枝不为叶子
    if not (isLeaf(tree['lchild'])):
        tree['lchild'] = cutBranches(tree['lchild'], ldata)
    # 右枝不为叶子
    if not isLeaf(tree['rchild']):
        tree['rchild'] = cutBranches(tree['rchild'], rdata)
    # 两边都是树叶,开始判断要不要减支
    if isLeaf(tree['lchild']) and isLeaf(tree['rchild']):
        ldata, rdata = splitData(testData, tree['node'], tree['val'])
        # 不进行剪枝的目标函数值
        noMerge = sum(np.power(ldata[:, -1] - tree['lchild'], 2)) + \
                  sum(np.power(rdata[:, -1] - tree['rchild'], 2))
        treeMean = 0.5 * (tree['lchild'] + tree['rchild'])
        Merge = sum(np.power(ldata[:, -1] - treeMean, 2))
        # 判断是否要剪枝
        if Merge < noMerge:
            print "merging..."
            return treeMean  # 返回左右子树平均值实现合并
        else:
            return tree
    else:
        return tree

下图为剪枝之前CART生成的树,可以看到,生成的树臃肿,带有大量的叶子结点。如果不设置ops=1,则该树将会变得更加庞大臃肿,它甚至可能为每一个样本生成一个节点。

下图为经过后剪枝之后的CART树,多余的枝条被减去,整棵树较之前显得更为小巧。说明在一定数据集的支撑下,后剪枝能够起到一定效果。
这里写图片描述
画树函数drawTree则判断节点是否有lchild属性或者rchild属性,如果有,则使用pygraphviz添加从节点到左右子树的节点然后递归调用画树函数,否则只添加有向边。

# 画树
def drawTree(graph, tree):
    # 递归画树
    if tree.has_key('lchild'):
        if not isLeaf(tree['lchild']):
            graph.add_edge(tree['val'], tree['lchild']['val'])
            drawTree(graph, tree['lchild'])
        else:
            graph.add_edge(tree['val'], tree['lchild'])
    if tree.has_key('rchild'):
        if not isLeaf(tree['rchild']):
            graph.add_edge(tree['val'], tree['rchild']['val'])
            drawTree(graph, tree['rchild'])
        else:
            graph.add_edge(tree['val'], tree['rchild'])
    drawTree(newGraph, tree)
    newGraph.layout(prog='dot')
    newGraph.draw('treeP.jpg')

代码和相关数据我传了一份供交流参考。地址:https://github.com/FlyingRoastDuck/CART_REG

参考文献

[1] 李航. 统计学习方法 [M]. 北京:清华大学出版社, 2012: 65-70.
[2] Peter H. 机器学习实战 [M]. 北京:人民邮电出版社,2013: 161-170.

  • 3
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值