主要内容:CART 算法、回归树与模型树、树剪枝算法,
并在最后对回归树、模型树以及标准线性回归进行了比较。
CART算法:
CART_regression.py
'''
Created on 2018年8月1日
@author: hcl
'''
from numpy import *
def loadDataSet(fileName):
    '''
    Read a tab-separated text file and return its rows as lists of floats.

    Each non-blank line is split on tab characters and every field is
    converted with float(). The rest of this module treats the last column
    of each row as the target value.

    :param fileName: path of the tab-separated data file
    :return: list of rows, each row a list of floats
    :raises ValueError: if a field cannot be parsed as a float
    '''
    dataMat = []
    # `with` guarantees the file handle is closed, even if parsing fails.
    with open(fileName) as fr:
        for line in fr:
            line = line.strip()
            # Skip blank lines (e.g. a trailing newline at end of file);
            # otherwise float('') would raise ValueError.
            if not line:
                continue
            dataMat.append([float(field) for field in line.split('\t')])
    return dataMat
def regLeaf(dataSet):
    '''
    Build a regression-tree leaf: the mean of the target (last) column.

    chooseBestSplit() calls this when it decides not to split the data any
    further. For a regression tree the leaf model is just this constant.
    '''
    targetColumn = dataSet[:, -1]
    return mean(targetColumn)
def regErr(dataSet):
    '''
    Total squared error of the target (last) column around its mean.

    This is the population variance multiplied by the number of rows, i.e.
    the sum of squared deviations of the targets from their mean.
    '''
    rowCount = shape(dataSet)[0]
    return rowCount * var(dataSet[:, -1])
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    '''
    Find the best binary split of dataSet, or decide to make a leaf instead.

    Every (feature, value) pair over the non-target columns is tried. The
    split that minimises errType(left) + errType(right) wins.

    :param dataSet: 2-D data (numpy matrix or ndarray); last column is the target
    :param leafType: function that builds a leaf model from a data set
    :param errType: function that measures the error of a data set
    :param ops: (tolS, tolN). tolS is the minimum error reduction a split must
        achieve; tolN is the minimum number of rows allowed on each side.
    :return: (featureIndex, splitValue) when a split is made, otherwise
        (None, leafModel)
    '''
    tolS, tolN = ops[0], ops[1]
    # Stop if all target values are identical: there is nothing left to explain.
    # asarray(...).ravel().tolist() works for numpy matrices and 2-D ndarrays.
    if len(set(asarray(dataSet[:, -1]).ravel().tolist())) == 1:
        return None, leafType(dataSet)
    n = shape(dataSet)[1]
    S = errType(dataSet)
    bestS = inf; bestIndex = None; bestValue = 0
    # Exhaustively try every feature column (all but the last) and every
    # distinct value it takes as a threshold.
    for featIndex in range(n - 1):
        for splitVal in set(asarray(dataSet[:, featIndex]).ravel().tolist()):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Reject splits that leave fewer than tolN rows on either side.
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # No candidate split satisfied tolN: make a leaf explicitly, rather than
    # relying on S - inf < tolS (which fails when S is NaN).
    if bestIndex is None:
        return None, leafType(dataSet)
    # Stop if the best split does not reduce the error by at least tolS.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # Defensive re-check of the minimum side size for the chosen split.
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    # No early-stopping condition applies: return the split feature and value.
    return bestIndex, bestValue
def binSplitDataSet(dataSet, feature, value):
    '''
    Split the rows of dataSet into two subsets on column `feature`.

    :return: (rows where dataSet[:, feature] > value,
              rows where dataSet[:, feature] <= value)
    '''
    column = dataSet[:, feature]
    aboveRows = nonzero(column > value)[0]
    belowRows = nonzero(column <= value)[0]
    return dataSet[aboveRows, :], dataSet[belowRows, :]
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    '''
    Recursively build a CART tree (a regression tree or a model tree).

    :param dataSet: 2-D data whose last column is the target
    :param leafType: function that builds a leaf model (regLeaf gives a
        constant leaf; for a model tree the leaf is a linear model)
    :param errType: function that measures the error of a data set
    :param ops: (tolS, tolN) stopping parameters forwarded to chooseBestSplit
    :return: the leaf model when no split is made, otherwise a dict with keys
        'spInd', 'spVal', 'left' (rows with feature > spVal) and
        'right' (rows with feature <= spVal)
    '''
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # chooseBestSplit signals "make a leaf" with feat None, and val is then
    # the leaf model. Use `is None`, since feature index 0 is a valid split.
    if feat is None:
        return val
    retTree = {'spInd': feat, 'spVal': val}
    # Split the data in two and keep growing each half recursively.
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
def isTree(obj):
    '''
    Return True if obj is an internal tree node (a dict), False for a leaf.

    Internal nodes are the dicts built by createTree(); leaves are plain
    model values. isinstance also accepts dict subclasses such as OrderedDict,
    which the previous type-name string comparison rejected.
    '''
    return isinstance(obj, dict)
def getMean(tree):
    '''
    Collapse a tree into a single leaf value by recursive pairwise averaging.

    Each internal node becomes the average of its two (already collapsed)
    children. This is an unweighted average of the two branches, not the
    mean of the original training targets.

    The input tree is left unmodified.
    '''
    left = getMean(tree['left']) if isTree(tree['left']) else tree['left']
    right = getMean(tree['right']) if isTree(tree['right']) else tree['right']
    return (left + right) / 2.0
def prune(tree, testData):
    '''
    Post-prune a regression tree using held-out test data.

    The function descends the tree and splits testData the same way each
    node does. When both children of a node are leaves, the node is replaced
    by the average of the two leaves if that lowers the squared error on the
    test rows reaching it.

    Child entries of the tree are reassigned in place. The pruned subtree,
    or a single merged leaf value, is returned. Leaves are expected to be
    scalars (regression tree).

    :param tree: a tree built by createTree(), or a leaf value
    :param testData: test rows whose last column is the target
    :return: the pruned subtree or a merged leaf value
    '''
    # A bare leaf has nothing to prune.
    if not isTree(tree):
        return tree
    # No test rows reach this node: collapse the whole subtree.
    if shape(testData)[0] == 0:
        return getMean(tree)
    # The split depends only on this node's spInd/spVal, which the recursion
    # below never changes, so compute it once.
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)
    # A child is still a subtree, so this node cannot be merged.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree
    # Both children are leaves: compare test error with and without merging.
    errorNoMerge = (sum(power(lSet[:, -1] - tree['left'], 2)) +
                    sum(power(rSet[:, -1] - tree['right'], 2)))
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = sum(power(testData[:, -1] - treeMean, 2))
    # Merge the two branches if that gives a lower error than keeping them.
    if errorMerge < errorNoMerge:
        print("merging")
        return treeMean
    return tree
if __name__ == '__main__':
    # Demo: fit a regression tree on ex00.txt and print its dict form.
    # asmatrix replaces mat: numpy.mat was removed in NumPy 2.0, and
    # asmatrix is the same function in NumPy 1.x.
    myData = loadDataSet('ex00.txt')
    myMat = asmatrix(myData)
    tree = createTree(myMat)
    print(tree)
输出:
{'spInd': 0, 'spVal': 0.48813, 'left': 1.0180967672413792, 'right': -0.04465028571428572}
树剪枝算法:
CART_regression_dec.py
'''
Created on 2018年8月1日
@author: hcl
'''
from numpy import *
import CART_regression
# Pre-pruning alternative (kept for reference): a very large tolS stops
# splitting unless a split reduces the error by at least that much.
# myData2 = CART_regression.loadDataSet('ex2.txt')
# myMat2 = asmatrix(myData2)
# myTree = CART_regression.createTree(myMat2, ops=(10000,4))
# print(myTree)

# Post-pruning: grow a tree with the loosest stopping parameters
# (tolS=0, tolN=1) on ex2.txt, then prune it with held-out ex2test.txt.
# asmatrix replaces mat, which was removed in NumPy 2.0.
myData2 = CART_regression.loadDataSet('ex2.txt')
myMat2 = asmatrix(myData2)
myTree = CART_regression.createTree(myMat2, ops=(0,1))
myDataTest = CART_regression.loadDataSet('ex2test.txt')
myMat2Test = asmatrix(myDataTest)
end_tree = CART_regression.prune(myTree, myMat2Test)
print(end_tree)
输出:
...
merging
merging
merging
{'spInd': 0, 'spVal': 0.499171, 'left': {'spInd': 0, 'spVal': 0.729397, 'left': {'spInd': 0, 'spVal': 0.952833, 'left': {'spInd': 0, 'spVal': 0.965969, 'left': 92.5239915, 'right': {'spInd': 0, 'spVal': 0.956951, 'left': {'spInd': 0, 'spVal': 0.958512, 'left': {'spInd': 0, 'spVal': 0.960398, 'left': 112.386764, 'right': 123.559747}, 'right': 135.837013}, 'right': 111.2013225}}, 'right': {'spInd': 0, 'spVal': 0.759504, 'left': {'spInd': 0, 'spVal': 0.763328, 'left': {'spInd': 0, 'spVal': 0.769043, 'left': {'spInd': 0, 'spVal': 0.790312, 'left': {'spInd': 0, 'spVal': 0.806158, 'left': {'spInd': 0, 'spVal': 0.815215, 'left': {'spInd': 0, 'spVal': 0.833026, 'left': {'spInd': 0, 'spVal': 0.841547, 'left': {'spInd': 0, 'spVal': 0.841625, 'left': {'spInd': 0, 'spVal': 0.944221, 'left': {'spInd': 0, 'spVal': 0.948822, 'left': 96.41885225, 'right': 69.318649}, 'right': {'spInd': 0, 'spVal': 0.85497, 'left': {'spInd': 0, 'spVal': 0.936524, 'left': 110.03503850000001, 'right': {'spInd': 0, 'spVal': 0.934853, 'left': 65.548418, 'right': {'spInd': 0, 'spVal': 0.925782, 'left': 115.753994, 'right': {'spInd': 0, 'spVal': 0.910975, 'left': {'spInd': 0, 'spVal': 0.912161, 'left': 94.3961145, 'right': 85.005351}, 'right': {'spInd': 0, 'spVal': 0.901444, 'left': {'spInd': 0, 'spVal': 0.908629, 'left': 106.814667, 'right': 118.513475}, 'right': {'spInd': 0, 'spVal': 0.901421, 'left': 87.300625, 'right': {'spInd': 0, 'spVal': 0.892999, 'left': {'spInd': 0, 'spVal': 0.900699, 'left': 100.133819, 'right': 108.094934}, 'right': {'spInd': 0, 'spVal': 0.888426, 'left': 82.436686, 'right': {'spInd': 0, 'spVal': 0.872199, 'left': 98.54454949999999, 'right': 106.16859550000001}}}}}}}}}, 'right': {'spInd': 0, 'spVal': 0.84294, 'left': {'spInd': 0, 'spVal': 0.847219, 'left': 89.20993, 'right': 76.240984}, 'right': 95.893131}}}, 'right': 60.552308}, 'right': 124.87935300000001}, 'right': {'spInd': 0, 'spVal': 0.823848, 'left': 76.723835, 'right': {'spInd': 0, 'spVal': 0.819722, 'left': 59.342323, 
'right': 70.054508}}}, 'right': {'spInd': 0, 'spVal': 0.811602, 'left': 118.319942, 'right': {'spInd': 0, 'spVal': 0.811363, 'left': 99.841379, 'right': 112.981216}}}, 'right': 73.49439925}, 'right': {'spInd': 0, 'spVal': 0.786865, 'left': 114.4008695, 'right': 102.26514075}}, 'right': 64.041941}, 'right': 115.199195}, 'right': 78.08564325}}, 'right': {'spInd': 0, 'spVal': 0.640515, 'left': {'spInd': 0, 'spVal': 0.642373, 'left': {'spInd': 0, 'spVal': 0.642707, 'left': {'spInd': 0, 'spVal': 0.665329, 'left': {'spInd': 0, 'spVal': 0.706961, 'left': {'spInd': 0, 'spVal': 0.70889, 'left': {'spInd': 0, 'spVal': 0.716211, 'left': 110.90283, 'right': {'spInd': 0, 'spVal': 0.710234, 'left': 103.345308, 'right': 108.553919}}, 'right': 135.416767}, 'right': {'spInd': 0, 'spVal': 0.698472, 'left': {'spInd': 0, 'spVal': 0.69892, 'left': {'spInd': 0, 'spVal': 0.699873, 'left': {'spInd': 0, 'spVal': 0.70639, 'left': 106.180427, 'right': 105.062147}, 'right': 115.586605}, 'right': 92.470636}, 'right': {'spInd': 0, 'spVal': 0.689099, 'left': 120.521925, 'right': {'spInd': 0, 'spVal': 0.666452, 'left': 101.91115275, 'right': 112.78136649999999}}}}, 'right': {'spInd': 0, 'spVal': 0.661073, 'left': 121.980607, 'right': {'spInd': 0, 'spVal': 0.652462, 'left': 115.687524, 'right': 112.715799}}}, 'right': 82.500766}, 'right': 140.613941}, 'right': {'spInd': 0, 'spVal': 0.613004, 'left': {'spInd': 0, 'spVal': 0.623909, 'left': {'spInd': 0, 'spVal': 0.628061, 'left': {'spInd': 0, 'spVal': 0.637999, 'left': 82.713621, 'right': {'spInd': 0, 'spVal': 0.632691, 'left': 91.656617, 'right': 93.645293}}, 'right': {'spInd': 0, 'spVal': 0.624827, 'left': 117.628346, 'right': 105.970743}}, 'right': 82.04976400000001}, 'right': {'spInd': 0, 'spVal': 0.606417, 'left': 168.180746, 'right': {'spInd': 0, 'spVal': 0.513332, 'left': {'spInd': 0, 'spVal': 0.533511, 'left': {'spInd': 0, 'spVal': 0.548539, 'left': {'spInd': 0, 'spVal': 0.553797, 'left': {'spInd': 0, 'spVal': 0.560301, 'left': {'spInd': 0, 
'spVal': 0.599142, 'left': 93.521396, 'right': {'spInd': 0, 'spVal': 0.589806, 'left': 130.378529, 'right': {'spInd': 0, 'spVal': 0.582311, 'left': 111.9849935, 'right': {'spInd': 0, 'spVal': 0.571214, 'left': 82.589328, 'right': {'spInd': 0, 'spVal': 0.569327, 'left': 114.872056, 'right': 108.435392}}}}}, 'right': 82.903945}, 'right': 129.0624485}, 'right': {'spInd': 0, 'spVal': 0.546601, 'left': 83.114502, 'right': {'spInd': 0, 'spVal': 0.537834, 'left': 97.3405265, 'right': 90.995536}}}, 'right': {'spInd': 0, 'spVal': 0.51915, 'left': {'spInd': 0, 'spVal': 0.531944, 'left': 129.766743, 'right': 124.795495}, 'right': 116.176162}}, 'right': {'spInd': 0, 'spVal': 0.508548, 'left': 101.075609, 'right': {'spInd': 0, 'spVal': 0.508542, 'left': 93.292829, 'right': 96.403373}}}}}}}, 'right': {'spInd': 0, 'spVal': 0.457563, 'left': {'spInd': 0, 'spVal': 0.465561, 'left': {'spInd': 0, 'spVal': 0.467383, 'left': {'spInd': 0, 'spVal': 0.483803, 'left': {'spInd': 0, 'spVal': 0.487381, 'left': 8.53677, 'right': 27.729263}, 'right': 5.224234}, 'right': {'spInd': 0, 'spVal': 0.46568, 'left': -9.712925, 'right': -23.777531}}, 'right': {'spInd': 0, 'spVal': 0.463241, 'left': 30.051931, 'right': 17.171057}}, 'right': {'spInd': 0, 'spVal': 0.455761, 'left': -34.044555, 'right': {'spInd': 0, 'spVal': 0.126833, 'left': {'spInd': 0, 'spVal': 0.130626, 'left': {'spInd': 0, 'spVal': 0.382037, 'left': {'spInd': 0, 'spVal': 0.388789, 'left': {'spInd': 0, 'spVal': 0.437652, 'left': -4.1911745, 'right': {'spInd': 0, 'spVal': 0.412516, 'left': {'spInd': 0, 'spVal': 0.418943, 'left': {'spInd': 0, 'spVal': 0.426711, 'left': {'spInd': 0, 'spVal': 0.428582, 'left': 19.745224, 'right': 15.224266}, 'right': -21.594268}, 'right': 44.161493}, 'right': {'spInd': 0, 'spVal': 0.403228, 'left': -26.419289, 'right': 0.6359300000000001}}}, 'right': 23.197474}, 'right': {'spInd': 0, 'spVal': 0.335182, 'left': {'spInd': 0, 'spVal': 0.370042, 'left': {'spInd': 0, 'spVal': 0.378965, 'left': -29.007783, 
'right': {'spInd': 0, 'spVal': 0.373501, 'left': {'spInd': 0, 'spVal': 0.377383, 'left': 13.583555, 'right': 5.241196}, 'right': -8.228297}}, 'right': {'spInd': 0, 'spVal': 0.35679, 'left': -32.124495, 'right': {'spInd': 0, 'spVal': 0.350725, 'left': -9.9938275, 'right': -26.851234812500003}}}, 'right': {'spInd': 0, 'spVal': 0.324274, 'left': 22.286959625, 'right': {'spInd': 0, 'spVal': 0.309133, 'left': {'spInd': 0, 'spVal': 0.310956, 'left': -20.3973335, 'right': -49.939516}, 'right': {'spInd': 0, 'spVal': 0.131833, 'left': {'spInd': 0, 'spVal': 0.138619, 'left': {'spInd': 0, 'spVal': 0.156067, 'left': {'spInd': 0, 'spVal': 0.166765, 'left': {'spInd': 0, 'spVal': 0.193282, 'left': {'spInd': 0, 'spVal': 0.211633, 'left': {'spInd': 0, 'spVal': 0.228473, 'left': {'spInd': 0, 'spVal': 0.25807, 'left': {'spInd': 0, 'spVal': 0.284794, 'left': {'spInd': 0, 'spVal': 0.300318, 'left': 8.814725, 'right': {'spInd': 0, 'spVal': 0.297107, 'left': -18.051318, 'right': {'spInd': 0, 'spVal': 0.295993, 'left': -1.798377, 'right': {'spInd': 0, 'spVal': 0.290749, 'left': -14.988279, 'right': -14.391613}}}}, 'right': {'spInd': 0, 'spVal': 0.273863, 'left': 35.623746, 'right': {'spInd': 0, 'spVal': 0.264926, 'left': -9.457556, 'right': {'spInd': 0, 'spVal': 0.264639, 'left': 5.280579, 'right': 2.557923}}}}, 'right': {'spInd': 0, 'spVal': 0.228628, 'left': {'spInd': 0, 'spVal': 0.228751, 'left': -9.601409499999999, 'right': -30.812912}, 'right': -2.266273}}, 'right': 6.099239}, 'right': {'spInd': 0, 'spVal': 0.202161, 'left': -16.42737025, 'right': -2.6781805}}, 'right': 9.5773855}, 'right': {'spInd': 0, 'spVal': 0.156273, 'left': {'spInd': 0, 'spVal': 0.164134, 'left': {'spInd': 0, 'spVal': 0.166431, 'left': -14.740059, 'right': -6.512506}, 'right': -27.405211}, 'right': 0.225886}}, 'right': {'spInd': 0, 'spVal': 0.13988, 'left': 7.557349, 'right': 7.336784}}, 'right': -29.087463}, 'right': 22.478291}}}}}, 'right': -39.524461}, 'right': {'spInd': 0, 'spVal': 0.124723, 'left': 
22.891675, 'right': {'spInd': 0, 'spVal': 0.085111, 'left': {'spInd': 0, 'spVal': 0.108801, 'left': 6.196516, 'right': {'spInd': 0, 'spVal': 0.10796, 'left': -16.106164, 'right': {'spInd': 0, 'spVal': 0.085873, 'left': -1.293195, 'right': -10.137104}}}, 'right': {'spInd': 0, 'spVal': 0.084661, 'left': 37.820659, 'right': {'spInd': 0, 'spVal': 0.080061, 'left': -24.132226, 'right': {'spInd': 0, 'spVal': 0.068373, 'left': 15.824970500000001, 'right': {'spInd': 0, 'spVal': 0.061219, 'left': -15.160836, 'right': {'spInd': 0, 'spVal': 0.044737, 'left': {'spInd': 0, 'spVal': 0.053764, 'left': {'spInd': 0, 'spVal': 0.055862, 'left': 6.695567, 'right': -3.131497}, 'right': -13.731698}, 'right': 4.091626}}}}}}}}}}}
模型树算法:
model_tree.py
'''
Created on 2018年8月1日
@author: hcl
'''
from numpy import *
import CART_regression
def linearSolve(dataSet):
    '''
    Fit ordinary least squares to dataSet.

    A column of ones is prepended to the feature columns so that ws[0] is the
    intercept. The last column of dataSet is the target Y.

    :param dataSet: 2-D data (matrix, ndarray or nested list), last column = Y
    :return: (ws, X, Y). ws is the (n, 1) coefficient matrix, X is the (m, n)
        design matrix with the leading ones column, and Y is the (m, 1)
        target column.
    :raises NameError: if X.T * X is singular (this exception type is kept for
        compatibility with existing callers; increasing the second value of
        ops, tolN, gives larger leaves and usually avoids it)
    '''
    data = asmatrix(dataSet)
    m, n = shape(data)
    # Design matrix: column 0 is all ones (intercept), the rest are features.
    X = asmatrix(ones((m, n)))
    X[:, 1:n] = data[:, 0:n-1]
    Y = data[:, -1]  # the last column is the target
    xTx = X.T * X
    # If X.T * X has no inverse, the normal equations cannot be solved.
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n'
                        'try increasing the second value of ops')
    # Solve the normal equations directly instead of forming the explicit
    # inverse: same solution, better numerical behaviour.
    ws = asmatrix(linalg.solve(xTx, X.T * Y))
    return ws, X, Y
def modelLeaf(dataSet):
    '''
    Build a model-tree leaf: the linear-regression coefficients of dataSet.

    Only the coefficient matrix from linearSolve() is kept. The design matrix
    and target column it also returns are not needed for a leaf.
    '''
    return linearSolve(dataSet)[0]
def modelErr(dataSet):
    '''
    Error of a model-tree leaf: sum of squared residuals of a linear fit.
    '''
    coeffs, design, target = linearSolve(dataSet)
    residuals = target - design * coeffs
    return sum(power(residuals, 2))
def regTreeEval(model, inDat):
    '''
    Evaluate a regression-tree leaf: the leaf value itself, as a float.

    inDat is unused. It is accepted so the signature matches modelTreeEval(),
    and either function can be passed as modelEval.
    '''
    leafValue = float(model)
    return leafValue
def modelTreeEval(model, inDat):
    '''
    Evaluate a model-tree leaf by applying its linear coefficients to one sample.

    A leading 1 is prepended to the sample, so model[0] acts as the intercept.

    :param model: (n+1, 1) coefficient matrix, as produced by modelLeaf()
    :param inDat: one sample with n feature values (1 x n matrix/array,
        1-D array or list)
    :return: the prediction as a Python float
    '''
    # asmatrix makes any sample form a 1 x n row (numpy.mat was removed in
    # NumPy 2.0; asmatrix is the equivalent).
    row = asmatrix(inDat)
    n = shape(row)[1]
    X = asmatrix(ones((1, n + 1)))
    X[:, 1:n + 1] = row
    # Take the single element explicitly: float() on a 1x1 array is deprecated.
    return float((X * model)[0, 0])
def createForeCast(tree, testData, modelEval=regTreeEval):
    '''
    Predict every row of testData by calling treeForeCast() on each row.

    :param tree: tree built by createTree(), or a single leaf
    :param testData: test samples, one per row (feature columns only)
    :param modelEval: leaf evaluator, regTreeEval or modelTreeEval
    :return: (m, 1) matrix of predictions
    '''
    m = len(testData)
    # asmatrix replaces mat, which was removed in NumPy 2.0.
    yHat = asmatrix(zeros((m, 1)))
    for i in range(m):
        # Pass each sample as a 1 x n matrix.
        yHat[i, 0] = treeForeCast(tree, asmatrix(testData[i]), modelEval)
    return yHat
def treeForeCast(tree, inData, modelEval=regTreeEval):
    '''
    Predict the target for a single sample by walking the tree.

    Starting at the root, the walk goes to 'left' when the sample's value for
    feature spInd is greater than spVal, otherwise to 'right' (the same rule
    binSplitDataSet uses). When a leaf is reached it returns
    modelEval(leaf, inData).

    :param tree: tree built by CART_regression.createTree(), or a leaf
    :param inData: one sample (typically a 1 x n matrix)
    :param modelEval: leaf evaluator, regTreeEval (default) or modelTreeEval
    '''
    if not CART_regression.isTree(tree):
        return modelEval(tree, inData)
    # Read feature spInd of the sample. Indexing the 1 x n matrix directly
    # with inData[spInd] selects a *row*. That only works by accident for
    # single-feature data; it raises for multi-feature samples. Flattening
    # first selects the feature.
    featureValue = asarray(inData).ravel()[tree['spInd']]
    branch = tree['left'] if featureValue > tree['spVal'] else tree['right']
    return treeForeCast(branch, inData, modelEval)
if __name__ == '__main__':
    # Compare three predictors on bikeSpeedVsIq by the correlation between
    # test-set predictions and true targets.
    # asmatrix replaces mat, which was removed in NumPy 2.0.
    # Model tree on exp2.txt (kept for reference):
    # myMat2 = asmatrix(CART_regression.loadDataSet('exp2.txt'))
    # tree = CART_regression.createTree(myMat2,modelLeaf,modelErr,(1,10))
    # print(tree)
    trainMat = asmatrix(CART_regression.loadDataSet('bikeSpeedVsIq_train.txt'))
    testMat = asmatrix(CART_regression.loadDataSet('bikeSpeedVsIq_test.txt'))
    # Regression tree (constant leaves). Unlike the two below, this prints
    # the full 2x2 correlation matrix returned by corrcoef.
    myTree = CART_regression.createTree(trainMat, ops=(1,20))
    yHat = createForeCast(myTree, testMat[:,0])
    corr = corrcoef(yHat, testMat[:,1], rowvar=0)
    print('regression tree corr:',corr)
    # Model tree (linear-model leaves).
    myTree = CART_regression.createTree(trainMat, modelLeaf, modelErr, (1,20))
    yHat = createForeCast(myTree, testMat[:,0], modelTreeEval)
    corr = corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]
    print('model tree corr:',corr)
    # Standard linear regression fitted on the whole training set.
    ws, X, Y = linearSolve(trainMat)
    print('ws:',ws)
    # Vectorised prediction (intercept + slope * x for every test row) instead
    # of overwriting the model-tree predictions element by element.
    yHat = testMat[:,0]*ws[1,0] + ws[0,0]
    corr = corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]
    print('standardization tree corr:',corr)
输出:
regression tree corr: [[1. 0.96408523]
[0.96408523 1. ]]
model tree corr: 0.9760412191380629
ws: [[37.58916794]
[ 6.18978355]]
standardization tree corr: 0.9434684235674766