Machine Learning In Action - Chapter 9 Tree-based regression

Chapter9 - Tree-based regression

CART是classification and regression tree,分类与回归树,正如名字所说的,它其实有两种树,分类树和回归树。第三章中讲的决策树是ID3决策树,根据信息增益作为特征选择算法。

CART树与前面说的树有什么差别呢?

1.之前的生成树的算法在对某个特征切分的时候,将数据集按这个特征的所有取值分成很多部分,这样的切分速度太快,而CART只进行二元切分,对每个特征只切分成两个部分。

2.ID3和C4.5只能处理离散型变量,而CART因为是二元切分,可以处理连续型变量,而且只要将特征选择的算法改一下的话既可以生成回归树。

本章讲了回归树和模型树

回归树
  • 特征选择

    回归树使用的是平方误差最小法作为特征选择算法。

    其思想是将当前节点的数据按照某个特征在某个切分点分成两类,比如 R1,R2 ,其对应的类别为 C1,C2 ,我们的任务就是找到一个切分点使误差最小,那么怎么度量误差呢?这里使用的是平方误差,即

    min[minxiR1(yic1)2+minxiR2(yic2)2]

    遍历某个特征可取的s个切分点(对离散型变量,要么等于要么不等于;对连续型变量,<或者>=),选择使上式最小的切分点。

    对每个确定的集合,c1,c2取平均值 xiR1(yic1)2 xiR2(yic2)2 才会最小,这样的话就是求划分为两个集合后,分别对每个集合求方差*实例数,加起来的最小值。

  • 剪枝

    简单的剪枝,如果merge后的误差更小就merge

python实现

# 选取最佳分裂特征和值
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]
    # 全属于一类
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        # print array(dataSet[:,featIndex].T).tolist()
        for splitVal in set(array(dataSet[:,featIndex].T)[0].tolist()):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): 
                continue
            # 平方误差
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # 当新分裂的误差与为分裂的误差小于一个阈值,就不分裂
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex,bestValue

# 创建树
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat == None: return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

#简单的剪枝,如果merge后的误差更小就merge
def prune(tree, testData):
    if shape(testData)[0] == 0: 
        return getMean(tree)
    if (isTree(tree['right']) or isTree(tree['left'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])
    if isTree(tree['left']): 
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): 
        tree['right'] = prune(tree['right'], rSet)
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])
        errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
                sum(power(rSet[:,-1] - tree['right'],2))
        treeMean = (tree['left']+tree['right'])/2.0
        errorMerge = sum(power(testData[:,-1] - treeMean,2))
        if errorMerge < errorNoMerge:
            print "merging"
            return treeMean
        else: 
            return tree
    else: return tree
模型树

树的叶子节点不是一个数值,而是一个模型的参数,如果叶子节点是线性回归模型,那么叶子节点存的就是权值系数w

python实现

# 叶子节点存放的东西
def modelLeaf(dataSet):
    ws,X,Y = linearSolve(dataSet)
    return ws

#线性模型的误差
def modelErr(dataSet):
    ws,X,Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))

createTree(myMat2, modelLeaf, modelErr,(1,10))
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值