数据挖掘十大算法(十):CART(分类回归树)

本文记录一下关于CART的相关知识其中包括(回归树、树的后剪枝、模型树、树回归模型的预测(树回归模型的评估))。在之前学习完ID3算法有记录一篇相关的学习笔记,所以后面学习CART算法能有一个比较和熟悉的理解。

    贪心算法的决策树,构建算法是ID3,即通过香农熵计算数据的混乱程度,然后求出信息增益,每次选择最大信息增益的划分方式,作为当前的划分方式,直到数据集完成划分,被划分过的特征在之后不会再有任何作用。所以这种划分方式被认为过于迅速,并且处理连续型数据时需要先离散化,这样可能会破坏连续型数据的内在性质。

    另一种切分方式是二元切分法即每次把数据切成两份。如果数据的某特征值等于切分所要求的值,那么这些数据就进入左子树,反之则进入右子树,这就是CART算法的思想。

    CART(分类回归树)算法,该算法既可以用来分类还可以用来回归,所以很值得学习。下面首先使用CART算法构建回归树,并介绍如何为复杂的回归树剪枝(防止过拟合问题)。然后引入一种更高级的方法——模型树。最后对回归树、模型树、线性回归做一个预测(评估)。

    模型树与回归树(在叶子节点使用各自的均值做预测)不同,该算法需要在每个叶子节点构建出一个线性模型

一个核心递归伪代码:

找到最佳的待切分特征: 
    如果该节点不能再分,将该节点存为叶节点 
    执行二元切分 
    在右子树继续调用该函数
    在左子树继续调用该函数

回归树:

说明:创建树函数creatTree()的两个参数默认值为 回归树的叶子节点创建函数、误差计算函数,所以这决定了如果使用默认值,则创建的是回归树。后面我们需要构建模型树,只需要改为传入模型树的两个函数参数即可。参数ops为预剪枝方法,该参数的设置决定了树构建的大小。

样例数据(来自第九章):

from numpy import *
import matplotlib.pyplot as plt

# 读取本地文件,python3 list(map)
def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float,curLine))  # python3问题修改
        dataMat.append(fltLine)
    return dataMat

# 根据特征值划分数据集,得到两个数据集
def binSplitDataSet(dataSet, feature, value):   # nonzero返回真(True)值的下标
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:] # 取该列某值大于特征值的行
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]   # python3问题修改
    return mat0,mat1

# 返回一个值,生成叶子节点(目标变量均值)
def regLeaf(dataSet):
    return mean(dataSet[:,-1])

# 误差计算函数 返回方差总和
def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]   # var方差计算函数

# 选择最佳的特征、特征值   (一旦不满足划分的条件便返回叶子节点)
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:  # 特征值唯一,返回None和叶子节点
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    S = errType(dataSet)    # 获得数据集的混乱程度误差,后面求混乱程度减少了多少
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):   # python3问题修改
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)      # 二分 划分数据集
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue # 若划分效果不好(数据集太小),继续划分
            newS = errType(mat0) + errType(mat1)    # 两个数据集的混乱程度求和,与bestS相比较
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # 如果混乱程度减少不大,则返回叶子节点
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) # 根据最佳的特征、特征值来二分划分数据
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # 若划分效果不好(数据集太小),返回叶子节点
        return None, leafType(dataSet)
    return bestIndex,bestValue # 返回最佳特征、特征值

# 数据集、创建叶子节点、误差计算函数、(1:最小的误差下降阈值 4:切分的最少样本数要求)
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops) # 选择最佳特征、特征值
    if feat == None: return val  # 若特征为None,返回叶子节点值
    retTree = {}    # 创建字典,用于保存树节点的信息
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)    # 根据已经划分返回的特征、特征值继续划分数据集
    retTree['left'] = createTree(lSet, leafType, errType, ops)  # 这两个函数为递归,直到叶子节点
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

myDat2 = loadDataSet('ex00.txt')
myMat2 = mat(myDat2)
result = createTree(myMat2)
print(result)

本段代码在书中有几处错误,我找到了两处,另一处参考了一篇博客:

1、TypeError: unsupported operand type(s) for /: ‘map‘ and ‘int‘ 
    修改loadDataSet函数某行为fltLine = list(map(float,curLine)),因为python3中map的返回值变了,所以要加list() 
2、TypeError: unhashable type: ‘matrix’ 
    修改chooseBestSplit函数某行为:for splitVal in set((dataSet[:,featIndex].T.A.tolist())[0]): matrix类型不能被hash。 
3、TypeError: index 0 is out of bounds 
    函数修改两行binSplitDataSet 
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :] 
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]

上面的代码不是特别的复杂,核心思想就是通过二元切分法,用目前最佳的方式对数据进行切分。

看一下数据集的图型:

import matplotlib.pyplot as plt
myDat=loadDataSet('ex00.txt')
myMat=mat(myDat)
plt.plot(myMat[:,0],myMat[:,1],'ro')
plt.show()

树的剪枝:

    前面我们提过ops参数的设置,可以决定我们树的构建大小,可能过拟合也可能欠拟合。该参数的设置会对我们树在构建过程中就进行剪枝操作,所以这是一种预剪枝操作。下面介绍一下另一种后剪枝操作,一般需要将两种剪枝操作同时使用,能达到更好的剪枝效果。

    后剪枝操作:后剪枝需要将数据分为训练集、测试集,首先给定参数,构建足够复杂的树,然后从上而下找到叶子节点,用测试集来判断将这些叶子节点合并能否降低测试误差,如果可以则合并。

# 后剪枝操作
# 判断该节点是否为子节点(字典 True)
def isTree(obj):
    return (type(obj).__name__ == 'dict')

# 递归 从上到下遍历直到两个叶子节点计算它们的平均值(塌陷处理)
def getMean(tree):
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):                        # 塌陷处理 简单描述就是从最下面的叶子节点(通过某种计算方式)开始两两合并
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

# 修剪过程主函数
def prune(tree, testData):
    if shape(testData)[0] == 0: return getMean(tree)  # 如果没有测试集,塌陷处理(及getMean函数)
    if (isTree(tree['right']) or isTree(tree['left'])):  # 如果有树,则根据树的信息划分测试集
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)   # 测试集非空,有树,则继续prune递归对测试集进行切分
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    # 如果它们现在都是叶子,看看是否可以合并它们
    if not isTree(tree['left']) and not isTree(tree['right']):
        # 划分测试集,计算划分后的误差与划分前的误差,两者比较,若划分更好则合并操作,否则不合并直接返回
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + sum(power(rSet[:, -1] - tree['right'], 2))   # 未合并误差
        treeMean = (tree['left'] + tree['right']) / 2.0   # 合并及将两个叶子节点的值求均值
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))  # 合并后误差
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean # 合并便返回两者的均值
        else:
            return tree
    else:
        return tree

# 获得数据
myDat2 = loadDataSet('ex2.txt')
myMat2 = mat(myDat2)
# 创建尽可能大的树(0,1)
myTree = createTree(myMat2,ops=(0,1))
myDatTest = loadDataSet('ex2test.txt')
myMat2Test = mat(myDatTest)
# 剪枝过程
result = prune(myTree,myMat2Test)
print(result)

这里是剪枝函数,需要调用上面的树模型构建函数。

剪枝过程的判断条件、递归有点多,需要仔细的理解,当然一步一步来都是比较容易理解的。下面介绍更高级的模型树。

模型树:

    上面提到过,该方法与回归树不同的地方是:该算法需要在每个叶子节点构建出一个线性模型,取代回归树的均值表示法。

树模型的叶子节点可以是一个常数,当然也可以是分段的线性函数,下面来看一个图就明白:

模型树的可解释性优于回归树,同时具有更高的预测准确度。如图中如果我们使用分段的线性函数肯定比一组常数拟合的效果好,而分段点大概在0.3左右,等下我们构建出模型树后,便可以得到该分段点和分段函数了。

由于模型树的构建与就回归树大致相同,只是叶子节点的创建函数leafType()、误差计算函数errType()需要重新定义,以及传参改变原来的默认参数。下面为模型树的主要函数:

# 模型树构建
# 这里的模型树使用到了上面回归树的函数createTree(),该函数只需改变两个固定参数(子节点生成函数 误差计算函数)
# 便可以在回归树与模型树之间切换

# 获得线性回归系数  与线性回归那里一样
def linearSolve(dataSet):
    m,n = shape(dataSet)
    X = mat(ones((m,n)))
    Y = mat(ones((m,1)))
    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]
    xTx = X.T*X        # 线性回归公式代入
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws,X,Y # 返回回归系数 数据 目标值

# 创建叶子节点(即回归系数)
def modelLeaf(dataSet):#create linear model and return coeficients
    ws,X,Y = linearSolve(dataSet)
    return ws

# 预测目标值 用于与Y求平方误差和
def modelErr(dataSet):
    ws,X,Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat,2))

myMat = mat(loadDataSet('exp2.txt'))
myTree = createTree(myMat,modelLeaf,modelErr,(1,10)) # 传入中间两个用于构建模型树的函数参数
print(myTree)

同过模型树返回划分的信息我们来看看它线性回归的拟合线如何:

    y = kx + b           左值为b,右值为k,带入横坐标x得到y   y = 3.46+1.185x  y = 12x

可以看到,分段的拟合线,很不错,也更加的直观。

树回归于标准回归的评估(预测):

# (树回归模型与标准回归的预测)树回归模型与标准回归的评估

# 回归树的值计算函数
def regTreeEval(model, inDat):
    return float(model)     # 不是树,则为叶节点(值)

# 模型树的值计算函数
def modelTreeEval(model, inDat):
    n = shape(inDat)[1]     # 模型树通过线性回归系数来计算预测值
    X = mat(ones((1, n + 1)))   # 第一个值为1
    X[:, 1:n + 1] = inDat
    return float(X * model) # 测试数据向量*回归系数向量 得到预测值

# 预测 (通过树模型预测当前值)
def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree):    # 如果不是一颗树,则为叶节点(数值)
        return modelEval(tree, inData)
    # 根据树来查询当前测试数据位置
    if inData[tree['spInd']] > tree['spVal']:   #tree['spInd'] 本次树的划分特征点   inData[tree['spInd']] 该特征值
        if isTree(tree['left']):
            return treeForeCast(tree['left'], inData, modelEval)
        else:
            return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']):
            return treeForeCast(tree['right'], inData, modelEval)
        else:
            return modelEval(tree['right'], inData)

# 预测 (循环所有测试集)
def createForeCast(tree, testData, modelEval=regTreeEval):
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat

trainMat = mat(loadDataSet('bikeSpeedVsIq_train.txt'))
testMat = mat(loadDataSet('bikeSpeedVsIq_test.txt'))

myTree1 = createTree(trainMat,ops=(1,20))       # 创建回归树
yHat = createForeCast(myTree1,testMat[:,0])     # 预测
result1 = corrcoef(yHat,testMat[:,1],rowvar=0)[0,1] # 该函数计算预测值与真实值的相关系数
print(result1)

myTree2 = createTree(trainMat,modelLeaf,modelErr,ops=(1,20))    # 创建模型树
yHat = createForeCast(myTree2,testMat[:,0],modelTreeEval)
result2 = corrcoef(yHat,testMat[:,1],rowvar=0)[0,1]
# print(myTree2)
print(result2)

ws,X,y = linearSolve(trainMat)    # 标准回归
for i in range(shape(testMat)[0]):
    yHat[i] = testMat[i,0]*ws[1,0]+ws[0,0]
result3 = corrcoef(yHat,testMat[:,1],rowvar=0)[0,1]
print(result3)

从结果可以看到模型树的效果最好,回归树其次,标准回归的效果最差。

以上是所有内容,通过实践可以看到CART算法,相对于ID3算法确实有很大的优势。尤其是对分类和回归通吃更是让人欲罢不能,CART可以用于构建二元树并处理离散型或连续型数据的切分,使用不同的误差准则、叶节点创建,我们可以构建回归树和模型树。

 

参考书籍:《机器学习实战》

参考博客:https://blog.csdn.net/sinat_17196995/article/details/69621687    某条代码错误参考

  • 8
    点赞
  • 55
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值