【Python机器学习】树回归——将CART算法用于回归

要对数据的复杂关系建模,可以借用树结构来帮助切分数据,如何实现数据的切分?怎样才能知道是否已经充分切分?这些问题的答案取决于叶节点的建模方式。回归树假设叶节点是常数值,这种策略认为数据中的复杂关系可以用树结构来概括

为成功构建以分段常数为叶节点的树,需要度量出数据的一致性。事实上,在数据集上计算混乱度是非常简单的:首先计算所有数据的均值,然后计算每条数据的值到均值的差值。为了对正负值差同等看待,一般使用绝对值或平方值来代替上述差值。

构建树

构建回归树,需要补充一些新的代码。给定某个误差计算方法,该函数会找到数据集上最佳的二元切分方式。另外,该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。因此,函数只需要完成两件事:用最佳方式切分数据集和生成相应的叶节点。

下面的代码中,chooseBestSplit()最复杂,该函数的目标是找到数据集切分的最佳位置。它遍历所有的特征及其可能的取值来找到使误差最小化的切分阈值。该函数的伪代码大致如下:

对每个特征:

    对每个特征值:

        将数据集切分成两份

        计算切分的误差

        如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差

返回最佳切分的特征和阈值

具体实现代码:

def binSplitDataSet(dataSet,feature,value):
    mat0=dataSet[nonzero(dataSet[:,feature]>value)[0],:][0]
    mat1=dataSet[nonzero(dataSet[:,feature]<=value)[0],:][0]
    return mat0,mat1

def regLeaf(dataSet):
    # 负责生成叶节点。当chooseBestSplit()函数确定不再对数据进行切分时,调用本函数来得到叶节点的模型,在回归树中,该模型其实就是目标变量的均值
    return mean(dataSet[:,-1])

def regErr(dataSet):
    #在给定数据上计算目标变量的平方误差。
    return var(dataSet[:,-1])*shape(dataSet)[0]

def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
    #回归树构建的核心函数,目的是找到数据的最佳二元切分方式。
    #如果找不到一个好的二元切分,返回None并同时调用createTree()来产生叶节点,叶节点的值也会返回None
    #tolS和tolN是用户指定的参数,用于控制函数的停止时机。其中tolS是容许的误差下降值,tolN是切分的最小样本数。
    tolS=ops[0]
    tolN=ops[1]
    if len(set(dataSet[:,-1].T.tolist()[0]))==1:
        #如果剩余特征的数目为1,那么就不需要再切分而直接返回
        return None,leafType(dataSet)
    m,n=shape(dataSet)
    S=errType(dataSet)
    bestS=inf
    bestIndex=0
    bestValue=0
    for featIndex in range(n-1):
        for splitVal in set(dataSet[:,featIndex]):
            mat0,mat1=binSplitDataSet(dataSet,featIndex,splitVal)
            if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):
                continue
            newS=errType(mat0)+errType(mat1)
            if newS<bestS:
                bestIndex=featIndex
                bestValue=splitVal
                bestS=newS
    if (S-bestS)<tolS:
        #如果切分数据集后效果提升不够大,那么就不进行切分操作而直接创建叶节点
        return None,leafType(dataSet)
    mat0,mat1=binSplitDataSet(dataSet,bestIndex,bestValue)
    if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):
        return bestIndex,bestValue

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值