树回归

主要内容:CART算法、回归与模型树、树减枝算法

并在最后进行了 回归树、模型树以及标准回归之间的比较

CART算法:

CART_regression.py

'''
Created on 2018年8月1日

@author: hcl
'''
from numpy import *

def loadDataSet(fileName):
    '''
    读取一个一tab键为分隔符的文件,然后将每行的内容保存成一组浮点数    
    '''
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float,curLine))
        dataMat.append(fltLine)
    return dataMat


def regLeaf(dataSet):
    '''负责生成叶节点'''
    #当chooseBestSplit()函数确定不再对数据进行切分时,将调用本函数来得到叶节点的模型。
    #在回归树中,该模型其实就是目标变量的均值。
    return mean(dataSet[:,-1])

def regErr(dataSet):
    '''
    误差估计函数,该函数在给定的数据上计算目标变量的平方误差,这里直接调用均方差函数
    '''
    return var(dataSet[:,-1]) * shape(dataSet)[0]#返回总方差

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    '''
    用最佳方式切分数据集和生成相应的叶节点
    '''  
    #ops为用户指定参数,用于控制函数的停止时机
    tolS = ops[0]; tolN = ops[1]
    #如果所有值相等则退出
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    #在所有可能的特征及其可能取值上遍历,找到最佳的切分方式
    #最佳切分也就是使得切分后能达到最低误差的切分
    for featIndex in range(n-1):
#        for splitVal in set(dataSet[:,featIndex]):
        for splitVal in set((dataSet[:,featIndex].T.A.tolist())[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): 
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS: 
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #如果误差减小不大则退出
    if (S - bestS) < tolS: 
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    #如果切分出的数据集很小则退出
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    #提前终止条件都不满足,返回切分特征和特征值
    return bestIndex,bestValue

def binSplitDataSet(dataSet, feature, value):
    '''
    数据集切分函数    
    '''
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
    return mat0,mat1

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    '''
    树构建函数
    leafType:建立叶节点的函数
    errType:误差计算函数
    ops:包含树构建所需其他参数的元组    
    '''    
    #选择最优的划分特征
    #如果满足停止条件,将返回None和某类模型的值
    #若构建的是回归树,该模型是一个常数;如果是模型树,其模型是一个线性方程
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat == None: return val #
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    #将数据集分为两份,之后递归调用继续划分
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

def isTree(obj):
    return (type(obj).__name__=='dict')
 
def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0
     
def prune(tree, testData):
    if shape(testData)[0] == 0: return getMean(tree) #if we have no test data collapse the tree
    if (isTree(tree['right']) or isTree(tree['left'])):#if the branches are not trees try to prune them
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] =  prune(tree['right'], rSet)
    #if they are now both leafs, see if we can merge them
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
            sum(power(rSet[:,-1] - tree['right'],2))
        treeMean = (tree['left']+tree['right'])/2.0
        errorMerge = sum(power(testData[:,-1] - treeMean,2))
        #如果两个分支合并后误差比未合并小,那么久合并两个分支
        if errorMerge < errorNoMerge: 
            print ("merging")
            return treeMean
        else: return tree
    else: return tree


if __name__ == '__main__':
    myData = loadDataSet('ex00.txt')
    myMat = mat(myData)
    tree = createTree(myMat)
    print(tree)

输出:

{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}

树减枝算法:

CART_regression_dec.py

'''
Created on 2018年8月1日

@author: hcl
'''
from numpy import *
import CART_regression

# myData2 = CART_regression.loadDataSet('ex2.txt')
# myMat2 = mat(myData2)
# myTree = CART_regression.createTree(myMat2, ops=(10000,4))
# print(myTree)


myData2 = CART_regression.loadDataSet('ex2.txt')
myMat2 = mat(myData2)
myTree = CART_regression.createTree(myMat2, ops=(0,1))
myDataTest = CART_regression.loadDataSet('ex2test.txt')
myMat2Test = mat(myDataTest)
end_tree = CART_regression.prune(myTree, myMat2Test)
print(end_tree)

输出:

...
merging
merging
merging
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.965969, 'left': 92.5239915, 'right': {'spInd': 0, 'spVal': 0.956951, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': {'spInd': 0, 'spVal': 0.960398, 'left': 112.386764, 'right': 123.559747}, 'right': 135.837013}, 'right': 111.2013225}}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.763328, 'left': {'spInd': 0, 'spVal': 0.769043, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.806158, 'left': {'spInd': 0, 'spVal': 0.815215, 'left': {'spInd': 0, 'spVal': 0.833026, 'left': {'spInd': 0, 'spVal': 0.841547, 'left': {'spInd': 0, 'spVal': 0.841625, 'left': {'spInd': 0, 'spVal': 0.944221, 'left': {'spInd': 0, 'spVal': 0.948822, 'left': 96.41885225, 'right': 69.318649}, 'right': {'spInd': 0, 'spVal': 0.85497, 'left': {'spInd': 0, 'spVal': 0.936524, 'left': 110.03503850000001, 'right': {'spInd': 0, 'spVal': 0.934853, 'left': 65.548418, 'right': {'spInd': 0, 'spVal': 0.925782, 'left': 115.753994, 'right': {'spInd': 0, 'spVal': 0.910975, 'left': {'spInd': 0, 'spVal': 0.912161, 'left': 94.3961145, 'right': 85.005351}, 'right': {'spInd': 0, 'spVal': 0.901444, 'left': {'spInd': 0, 'spVal': 0.908629, 'left': 106.814667, 'right': 118.513475}, 'right': {'spInd': 0, 'spVal': 0.901421, 'left': 87.300625, 'right': {'spInd': 0, 'spVal': 0.892999, 'left': {'spInd': 0, 'spVal': 0.900699, 'left': 100.133819, 'right': 108.094934}, 'right': {'spInd': 0, 'spVal': 0.888426, 'left': 82.436686, 'right': {'spInd': 0, 'spVal': 0.872199, 'left': 98.54454949999999, 'right': 106.16859550000001}}}}}}}}}, 'right': {'spInd': 0, 'spVal': 0.84294, 'left': {'spInd': 0, 'spVal': 0.847219, 'left': 89.20993, 'right': 76.240984}, 'right': 95.893131}}}, 'right': 60.552308}, 'right': 124.87935300000001}, 'right': {'spInd': 0, 'spVal': 0.823848, 'left': 76.723835, 'right': {'spInd': 0, 'spVal': 0.819722, 'left': 59.342323, 'right': 70.054508}}}, 'right': {'spInd': 0, 'spVal': 0.811602, 'left': 118.319942, 'right': {'spInd': 0, 'spVal': 0.811363, 'left': 99.841379, 'right': 112.981216}}}, 'right': 73.49439925}, 'right': {'spInd': 0, 'spVal': 0.786865, 'left': 114.4008695, 'right': 102.26514075}}, 'right': 64.041941}, 'right': 115.199195}, 'right': 78.08564325}}, 'right': {'spInd': 0, 'spVal': 0.640515, 'left': {'spInd': 0, 'spVal': 0.642373, 'left': {'spInd': 0, 'spVal': 0.642707, 'left': {'spInd': 0, 'spVal': 0.665329, 'left': {'spInd': 0, 'spVal': 0.706961, 'left': {'spInd': 0, 'spVal': 0.70889, 'left': {'spInd': 0, 'spVal': 0.716211, 'left': 110.90283, 'right': {'spInd': 0, 'spVal': 0.710234, 'left': 103.345308, 'right': 108.553919}}, 'right': 135.416767}, 'right': {'spInd': 0, 'spVal': 0.698472, 'left': {'spInd': 0, 'spVal': 0.69892, 'left': {'spInd': 0, 'spVal': 0.699873, 'left': {'spInd': 0, 'spVal': 0.70639, 'left': 106.180427, 'right': 105.062147}, 'right': 115.586605}, 'right': 92.470636}, 'right': {'spInd': 0, 'spVal': 0.689099, 'left': 120.521925, 'right': {'spInd': 0, 'spVal': 0.666452, 'left': 101.91115275, 'right': 112.78136649999999}}}}, 'right': {'spInd': 0, 'spVal': 0.661073, 'left': 121.980607, 'right': {'spInd': 0, 'spVal': 0.652462, 'left': 115.687524, 'right': 112.715799}}}, 'right': 82.500766}, 'right': 140.613941}, 'right': {'spInd': 0, 'spVal': 0.613004, 'left': {'spInd': 0, 'spVal': 0.623909, 'left': {'spInd': 0, 'spVal': 0.628061, 'left': {'spInd': 0, 'spVal': 0.637999, 'left': 82.713621, 'right': {'spInd': 0, 'spVal': 0.632691, 'left': 91.656617, 'right': 93.645293}}, 'right': {'spInd': 0, 'spVal': 0.624827, 'left': 117.628346, 'right': 105.970743}}, 'right': 82.04976400000001}, 'right': {'spInd': 0, 'spVal': 0.606417, 'left': 168.180746, 'right': {'spInd': 0, 'spVal': 0.513332, 'left': {'spInd': 0, 'spVal': 0.533511, 'left': {'spInd': 0, 'spVal': 0.548539, 'left': {'spInd': 0, 'spVal': 0.553797, 'left': {'spInd': 0, 'spVal': 0.560301, 'left': {'spInd': 0, 'spVal': 0.599142, 'left': 93.521396, 'right': {'spInd': 0, 'spVal': 0.589806, 'left': 130.378529, 'right': {'spInd': 0, 'spVal': 0.582311, 'left': 111.9849935, 'right': {'spInd': 0, 'spVal': 0.571214, 'left': 82.589328, 'right': {'spInd': 0, 'spVal': 0.569327, 'left': 114.872056, 'right': 108.435392}}}}}, 'right': 82.903945}, 'right': 129.0624485}, 'right': {'spInd': 0, 'spVal': 0.546601, 'left': 83.114502, 'right': {'spInd': 0, 'spVal': 0.537834, 'left': 97.3405265, 'right': 90.995536}}}, 'right': {'spInd': 0, 'spVal': 0.51915, 'left': {'spInd': 0, 'spVal': 0.531944, 'left': 129.766743, 'right': 124.795495}, 'right': 116.176162}}, 'right': {'spInd': 0, 'spVal': 0.508548, 'left': 101.075609, 'right': {'spInd': 0, 'spVal': 0.508542, 'left': 93.292829, 'right': 96.403373}}}}}}}, 'right': {'spInd': 0, 'spVal': 0.457563, 'left': {'spInd': 0, 'spVal': 0.465561, 'left': {'spInd': 0, 'spVal': 0.467383, 'left': {'spInd': 0, 'spVal': 0.483803, 'left': {'spInd': 0, 'spVal': 0.487381, 'left': 8.53677, 'right': 27.729263}, 'right': 5.224234}, 'right': {'spInd': 0, 'spVal': 0.46568, 'left': -9.712925, 'right': -23.777531}}, 'right': {'spInd': 0, 'spVal': 0.463241, 'left': 30.051931, 'right': 17.171057}}, 'right': {'spInd': 0, 'spVal': 0.455761, 'left': -34.044555, 'right': {'spInd': 0, 'spVal': 0.126833, 'left': {'spInd': 0, 'spVal': 0.130626, 'left': {'spInd': 0, 'spVal': 0.382037, 'left': {'spInd': 0, 'spVal': 0.388789, 'left': {'spInd': 0, 'spVal': 0.437652, 'left': -4.1911745, 'right': {'spInd': 0, 'spVal': 0.412516, 'left': {'spInd': 0, 'spVal': 0.418943, 'left': {'spInd': 0, 'spVal': 0.426711, 'left': {'spInd': 0, 'spVal': 0.428582, 'left': 19.745224, 'right': 15.224266}, 'right': -21.594268}, 'right': 44.161493}, 'right': {'spInd': 0, 'spVal': 0.403228, 'left': -26.419289, 'right': 0.6359300000000001}}}, 'right': 23.197474}, 'right': {'spInd': 0, 'spVal': 0.335182, 'left': {'spInd': 0, 'spVal': 0.370042, 'left': {'spInd': 0, 'spVal': 0.378965, 'left': -29.007783, 'right': {'spInd': 0, 'spVal': 0.373501, 'left': {'spInd': 0, 'spVal': 0.377383, 'left': 13.583555, 'right': 5.241196}, 'right': -8.228297}}, 'right': {'spInd': 0, 'spVal': 0.35679, 'left': -32.124495, 'right': {'spInd': 0, 'spVal': 0.350725, 'left': -9.9938275, 'right': -26.851234812500003}}}, 'right': {'spInd': 0, 'spVal': 0.324274, 'left': 22.286959625, 'right': {'spInd': 0, 'spVal': 0.309133, 'left': {'spInd': 0, 'spVal': 0.310956, 'left': -20.3973335, 'right': -49.939516}, 'right': {'spInd': 0, 'spVal': 0.131833, 'left': {'spInd': 0, 'spVal': 0.138619, 'left': {'spInd': 0, 'spVal': 0.156067, 'left': {'spInd': 0, 'spVal': 0.166765, 'left': {'spInd': 0, 'spVal': 0.193282, 'left': {'spInd': 0, 'spVal': 0.211633, 'left': {'spInd': 0, 'spVal': 0.228473, 'left': {'spInd': 0, 'spVal': 0.25807, 'left': {'spInd': 0, 'spVal': 0.284794, 'left': {'spInd': 0, 'spVal': 0.300318, 'left': 8.814725, 'right': {'spInd': 0, 'spVal': 0.297107, 'left': -18.051318, 'right': {'spInd': 0, 'spVal': 0.295993, 'left': -1.798377, 'right': {'spInd': 0, 'spVal': 0.290749, 'left': -14.988279, 'right': -14.391613}}}}, 'right': {'spInd': 0, 'spVal': 0.273863, 'left': 35.623746, 'right': {'spInd': 0, 'spVal': 0.264926, 'left': -9.457556, 'right': {'spInd': 0, 'spVal': 0.264639, 'left': 5.280579, 'right': 2.557923}}}}, 'right': {'spInd': 0, 'spVal': 0.228628, 'left': {'spInd': 0, 'spVal': 0.228751, 'left': -9.601409499999999, 'right': -30.812912}, 'right': -2.266273}}, 'right': 6.099239}, 'right': {'spInd': 0, 'spVal': 0.202161, 'left': -16.42737025, 'right': -2.6781805}}, 'right': 9.5773855}, 'right': {'spInd': 0, 'spVal': 0.156273, 'left': {'spInd': 0, 'spVal': 0.164134, 'left': {'spInd': 0, 'spVal': 0.166431, 'left': -14.740059, 'right': -6.512506}, 'right': -27.405211}, 'right': 0.225886}}, 'right': {'spInd': 0, 'spVal': 0.13988, 'left': 7.557349, 'right': 7.336784}}, 'right': -29.087463}, 'right': 22.478291}}}}}, 'right': -39.524461}, 'right': {'spInd': 0, 'spVal': 0.124723, 'left': 22.891675, 'right': {'spInd': 0, 'spVal': 0.085111, 'left': {'spInd': 0, 'spVal': 0.108801, 'left': 6.196516, 'right': {'spInd': 0, 'spVal': 0.10796, 'left': -16.106164, 'right': {'spInd': 0, 'spVal': 0.085873, 'left': -1.293195, 'right': -10.137104}}}, 'right': {'spInd': 0, 'spVal': 0.084661, 'left': 37.820659, 'right': {'spInd': 0, 'spVal': 0.080061, 'left': -24.132226, 'right': {'spInd': 0, 'spVal': 0.068373, 'left': 15.824970500000001, 'right': {'spInd': 0, 'spVal': 0.061219, 'left': -15.160836, 'right': {'spInd': 0, 'spVal': 0.044737, 'left': {'spInd': 0, 'spVal': 0.053764, 'left': {'spInd': 0, 'spVal': 0.055862, 'left': 6.695567, 'right': -3.131497}, 'right': -13.731698}, 'right': 4.091626}}}}}}}}}}}

模型树算法:

model_tree.py

'''
Created on 2018年8月1日

@author: hcl
'''
from numpy import *
import CART_regression

def linearSolve(dataSet):
    '''将数据集格式化成目标变量Y和自变量X,X、Y用于执行简单线性回归'''
    m,n = shape(dataSet)
    X = mat(ones((m,n))); Y = mat(ones((m,1)))
    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]#默认最后一列为Y
    xTx = X.T*X
    #若矩阵的逆不存在,抛异常
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)#回归系数
    return ws,X,Y

def modelLeaf(dataSet):
    '''负责生成叶节点模型'''
    ws,X,Y = linearSolve(dataSet)
    return ws

def modelErr(dataSet):
    '''误差计算函数'''
    ws,X,Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat,2))

def regTreeEval(model, inDat):
    #为了和modeTreeEval()保持一致,保留两个输入参数
    return float(model)

def modelTreeEval(model, inDat):
    #对输入数据进行格式化处理,在原数据矩阵上增加第0列,元素的值都是1
    n = shape(inDat)[1]
    X = mat(ones((1,n+1)))
    X[:,1:n+1]=inDat
    return float(X*model)

def createForeCast(tree, testData, modelEval=regTreeEval):
    # 多次调用treeForeCast()函数,以向量形式返回预测值,在整个测试集进行预测非常有用
    m=len(testData)
    yHat = mat(zeros((m,1)))
    for i in range(m):
        yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat

def treeForeCast(tree, inData, modelEval=regTreeEval):
    '''
    # 在给定树结构的情况下,对于单个数据点,该函数会给出一个预测值。
    # modeEval是对叶节点进行预测的函数引用,指定树的类型,以便在叶节点上调用合适的模型。
    # 此函数自顶向下遍历整棵树,直到命中叶节点为止,一旦到达叶节点,它就会在输入数据上
    # 调用modelEval()函数,该函数的默认值为regTreeEval()    
    '''
    if not CART_regression.isTree(tree): return modelEval(tree, inData)
    if inData[tree['spInd']] > tree['spVal']:
        if CART_regression.isTree(tree['left']): return treeForeCast(tree['left'], inData, modelEval)
        else: return modelEval(tree['left'], inData)
    else:
        if CART_regression.isTree(tree['right']): return treeForeCast(tree['right'], inData, modelEval)
        else: return modelEval(tree['right'], inData)


if __name__ == '__main__':
#     myMat2 = mat(CART_regression.loadDataSet('exp2.txt'))
#     tree = CART_regression.createTree(myMat2,modelLeaf,modelErr,(1,10))
#     print(tree)

    #回归树
    trainMat = mat(CART_regression.loadDataSet('bikeSpeedVsIq_train.txt'))
    testMat = mat(CART_regression.loadDataSet('bikeSpeedVsIq_test.txt'))
    myTree = CART_regression.createTree(trainMat, ops=(1,20))
    yHat = createForeCast(myTree, testMat[:,0])
    corr = corrcoef(yHat, testMat[:,1], rowvar=0)
    print('regression tree corr:',corr)
    
    #模型树
    myTree = CART_regression.createTree(trainMat, modelLeaf, modelErr, (1,20))
    yHat = createForeCast(myTree, testMat[:,0], modelTreeEval)
    corr = corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]
    print('model tree corr:',corr)
    
    # 标准回归
    ws, X, Y = linearSolve(trainMat)
    print('ws:',ws)
    for i in range(shape(testMat)[0]) :
        yHat[i] = testMat[i,0]*ws[1,0] + ws[0,0]
    corr = corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]
    print('standardization tree corr:',corr)
    
    

输出:

regression tree corr: [[1.         0.96408523]
 [0.96408523 1.        ]]
model tree corr: 0.9760412191380629
ws: [[37.58916794]
 [ 6.18978355]]
standardization tree corr: 0.9434684235674766

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值