第九章 树回归
- CART算法
- 回归与模型树
- 树减枝算法
- python中GUI的使用
线性回归需要拟合所有的样本点(局部加权线性回归除外),当数据拥有众多特征并且特征之间关系十分复杂时,就不可能使用全局线性模型来拟合任何数据。
将数据集切分成很多份易建模的数据,再用线性回归技术来建模可破。
本章介绍CART(Classification And Regression Trees, 分类回归树)的树构建算法,可用于分类还可用于回归。
9.1 复杂数据的局部性建模
chap3的决策树主要是不断将数据切分成小数据集,直到所有目标变量完全相同,或者数据不能再切分为止。决策树是一种贪心算法,并不考虑能否达到全局最优。其构建算法ID3算法每次选取当前最佳的特征来分割数据,并按照该特征的所有可能取值来划分,之后该特征不会再起作用。另外一种方法是二元切分法,每次把数据集切成两份,如果数据的某特征等于切分所要求的值,那么这些数据就进入树的左子树,反之右子树。二元切分法可处理连续型特征,节省树的构建时间。
CART使用二元切分来处理连续型变量,应用广泛。
9.2 连续型和离散型特征的树的构建
与chap3类似,用字典存储树的结构。包括4元素:
- 待切分的特征
- 待切分的特征值
- 左子树。当不再需要切分的时候,可是个单个值。
- 右子树。类似左子树。
CART算法可固定树的数据结构,树包含左键和右键,可以存储 另一颗子树或者单个值。
伪代码:
找到最佳的待切分特征:
如果该节点不能再分,将该节点存为叶节点
执行二元切分
在右子树调用createTree()方法
在左子树调用createTree()方法
coding:
#!/usr/bin/env python # coding=utf-8 from numpy import * def loadDataSet(fileName): dataMat = [] fr = open(fileName) for line in fr.readlines(): curLine = line.strip().split("\n") fltLine = map(float, curLine) #将每行的每个元素映射为浮点数 dataMat.append(fltLine) return dataMat def binSplitDataSet(dataSet, feature,value):#数据集合,待切分的特征,特征值,将数据集合切分得到两个子集 mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0], :][0] mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0], :][0] return mat0, mat1 def createTree(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)): feat, val = chooseBestSplit(dataSet, leafType, errType, ops) #将数据集进行切分 if feat == None: return val else: retTree = {} retTree["spInd"] = feat retTree["spVal"] = val lSet, rSet = binSplitDataSet(dataSet, feat, val) retTree["left"] = createTree(lSet, leafType, errType, ops) #递归切分 retTree["right"] = createTree(rSet, leafType, errType, ops) return retTree testMat = mat(eye(4)) mat0, mat1 = binSplitDataSet(testMat, 1, 0.5) # print testMat print mat0 print mat1
9.3 将CART算法用于回归
回归树假设叶节点是常数值。这种策略认为数据中的复杂关系可以用树结构来概括。
为成功构建以分段常数为叶节点的树,需要度量出数据的一致性。首先计算所有数据的均值,然后计算每条 数据的值到均值的差值,一般使用绝对值或平方差 来代替差值 ,类似方差计算,方差为平方误差的均值(均方差),这里需要计算平方误差的总值(总方差),均方差(var函数)乘以数据集样本点的个数可破。
构建树
chooseBestSplit()函数目标是找到数据集切分的最佳位置,遍历所有的特征及其可能的取值来找到使误差最小化的切分阈值。
伪代码:
对每个特征:
对每个特征值:
将数据集切分成两份
计算切分的误差
如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差 返回最佳切分的特征和阈值
数据:
Figure 9-1: 实验数据部分样本数据
coding:
#==============回归树的切分函数============================= def regLeaf(dataSet): return mean(dataSet[:,-1]) #返回叶节点,回归树中的目标变量的均值 def regErr(dataSet): return var(dataSet[:,-1])*shape(dataSet)[0] #误差估计,计算目标变量的平方误差,需要返回总误差 def chooseBestSplit(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)): tolS = ops[0] #容许的误差下降值 tolN = ops[1] #切分的最少样本数 if len(set(dataSet[:,-1].T.tolist()[0]))==1: #如果目标值相等,退出 return None, leafType(dataSet) else: m,n = shape(dataSet) S = errType(dataSet) bestS = inf bestIndex = 0 bestValue = 0 for featIndex in range(n-1): #对所有特征进行遍历 for splitVal in set(dataSet[:,featIndex]): #遍历某个特征的所有特征值 mat0,mat1 = binSplitDataSet(dataSet, featIndex, splitVal) #按照某个特征的某个值将数据切分成两个数据子集 if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN): #如果某个子集行数不大于tolN,也不应该切分 continue newS = errType(mat0) + errType(mat1) #新误差由切分后的两个数据子集组成的误差 if newS < bestS: #判断新切分能否降低误差, bestIndex = featIndex bestValue = splitVal bestS = newS if (S - bestS) <tolS: #如果误差不大则退出 return None, leafType(dataSet) mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN): #如果切分出的数据集很小则退出 return None, leafType(dataSet) return bestIndex, bestValue #===================================================================== myDat = loadDataSet("ex00.txt") myMat = mat(myDat) print createTree(myMat) myDat1 = loadDataSet("ex0.txt") myMat1 = mat(myDat1) print createTree(myMat1)
切分效果:
Figure 9-2: 切分效果
9.4 树剪枝
一棵树如果节点过多,表明该模型可能对数据进行了”过拟合“。可使用测试集上交叉验证来发现过拟合。
剪枝(pruning):降低决策树的复杂度来避免过拟合。分为预剪枝(prepruning)和后剪枝(postpruning),后者需要使用训练集和测试集。
预剪枝
树构建算法对输入参数tolS和tolN非常敏感,通过不断地修改停止条件来得到合理结果并不是很好的办法。
后剪枝
后剪枝需要使用测试集。首先指定参数,使得构建出的树足够大,足够复杂,便于剪枝。从上而下找到叶节点,用测试集来判断这些叶节点合并是否能降低测试误差。是则合并。
伪代码:
基于已有的树切分测试数据:
如果存在任一子集是一棵树,则在该子集递归剪枝过程
计算将当前两个叶节点合并后的误差
计算不合并的误差
如果合并会降低误差的化,就将叶节点合并
coding:
#=============回归树剪枝函数========================================== def isTree(obj): #测试输入变量是否是一棵树,用于判断当前处理的节点是否是叶节点 return (type(obj).__name__=="dict") def getMean(tree): #递归从上到下,到叶节点,找到两个叶节点则计算它们的平均值 if isTree(tree["right"]): tree["right"] = getMean(tree["right"]) if isTree(tree["left"]): tree["left"] = getMean(tree["left"]) return (tree["left"]+tree["right"])/2.0 def prune(tree, testData): if shape(testData)[0] == 0: return getMean(tree) #if we have no test data collapse the tree if (isTree(tree['right']) or isTree(tree['left'])):#if the branches are not trees try to prune them lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet) if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet) #if they are now both leafs, see if we can merge them if not isTree(tree['left']) and not isTree(tree['right']): lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\ sum(power(rSet[:,-1] - tree['right'],2)) treeMean = (tree['left']+tree['right'])/2.0 errorMerge = sum(power(testData[:,-1] - treeMean,2)) if errorMerge < errorNoMerge: print "merging" return treeMean else: return tree else: return tree myTree = createTree(myMat2, ops=(0,1)) myDatTest = loadDataSet("ex2test.txt") myMat2Test = mat(myDatTest) print prune(myTree, myMat2Test)
9.5 模型树
用树来对数据建模,除了把叶节点设定为常数值外,还可以将其设定为分段线性函数,分段线性(piecewise linear)即模型由多个线性片段组成。
可以设计两条分别从0.0-0.3、从0.3~1.0的直线,得到两个线性模型,即分段线性模型。
两条直线比很多节点组成一颗大树更容易理解。模型树的可解释性是它优于回归树的特点之一。模型树也具有更高的预测准确度。利用树生成算法对数据进行切分,且每份切分数据都能很容易被线性模型所表示,关键在于找到最佳切分。
coding:
效果:
#==========模型树的叶节点生成函数========= def linearSolve(dataSet): #执行简单的线性回归 m,n = shape(dataSet) X = mat(ones((m,n))) Y = mat(ones((m,1))) X[:, 1:n] = dataSet[:, 0:n-1] Y = dataSet[:, -1] #将X,Y中的数据格式化 xTx = X.T*X if linalg.det(xTx) == 0.0: raise NameError("This matrix is singular, cannot do inverse") ws = linalg.pinv(xTx)*(X.T*Y) ws = xTx.I*(X.T*Y) return ws, X, Y def modelLeaf(dataSet): #当数据不再需要切分的时候它负责生成叶节点模型 ws, X, Y = linearSolve(dataSet) return ws def modelErr(dataSet): ws, X, Y = linearSolve(dataSet) yHat = X*ws return sum(power(Y - yHat,2)) #计算平方误差 myMat2 = mat(loadDataSet("exp2.txt")) #print myMat2,type(myMat2) print createTree(myMat2, modelLeaf, modelErr, (1,10))
两个线性模型分别为y=3.468+1.185x和y=0.00168+11.964x,实际数据是由模型y=3.5+1.0x和y=0+12x再加上高斯噪音生成,可以看出效果还是不错。
9.6 示例:树回归与标准回归的比较
计算模型树、回归树及其他模型效果,比较客观的方法是计算相关系数,R*R值,Numpy中corrcoef(yHat, y, rowvar = 0)也即皮尔逊相关系数。
coding:
#================用树回归进行预测的代码============= def regTreeEval(model, inDat): return float(model) def modelTreeEval(model, inDat): n = shape(inDat)[1] X = mat(ones((1,n+1))) X[:,1:n+1] = inDat return float(X*model) def treeForeCast(tree, inData, modelEval=regTreeEval): if not isTree(tree): return modelEval(tree, inData) #如果输入单个数据或行向量,返回一个浮点值 else: if inData[tree["spInd"]] > tree["spVal"]: if isTree(tree["left"]): return treeForeCast(tree["left"], inData, modelEval) else: return modelEval(tree["left"], inData) else: if isTree(tree["right"]): return treeForeCast(tree["right"], inData, modelEval) else: return modelEval(tree["right"], inData) def createForeCast(tree, testData, modelEval=regTreeEval): m = len(testData) yHat = mat(zeros((m,1))) for i in range(m): yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)#多次调用treeForeCast函数,将结果以列的形式放到yHat变量中 return yHat trainMat = mat(loadDataSet("bikeSpeedVsIq_train.txt")) testMat = mat(loadDataSet("bikeSpeedVsIq_test.txt")) myTree = createTree(trainMat, ops=(1,20)) yHat = createForeCast(myTree, testMat[:,0]) print "回归树的皮尔逊相关系数:",corrcoef(yHat, testMat[:,1], rowvar=0)[0,1] myTree = createTree(trainMat, modelLeaf, modelErr,(1,20)) yHat = createForeCast(myTree, testMat[:,0], modelTreeEval) print "模型树的皮尔逊相关系数:",corrcoef(yHat, testMat[:,1], rowvar=0)[0,1] ws, X, Y = linearSolve(trainMat) print "线性回归系数:",ws for i in range(shape(testMat)[0]): yHat[i] = testMat[i,0]*ws[1,0] + ws[0,0] print "线性回归模型的皮尔逊相关系数:",corrcoef(yHat, testMat[:,1], rowvar=0)[0,1]
效果:
Figure 9-5: 回归树、模型树、简单线性回归的皮尔逊相关系数
可以看出模型树的结果比回归树效果好。
9.7 使用Python的Tkinter库创建GUI
本小结用到python的一个图形用户界面(GUI, Graphical User Interface)框架——Tkinter。
Tkinter的GUI由一些小部件组成(Widge)组成。小部件:文本框、按钮、标签和复选按钮等对象。
myLabel调用grid()方法时,把myLabel的位置告诉了布局管理器,grid()函数会把小部件安排在一个二维表格中。
coding:
#!/usr/bin/env python # coding=utf-8 #用于构建树管理器界面的Tkinter小部件 from numpy import * from Tkinter import * import regTrees def reDraw(tolS, tolN): pass def drawNewTree(): pass root = Tk() Label(root, text = "Plot Place Holder").grid(row=0, columnspan=3) #设置文本,第0行,距0的行值为3, Label(root, text = "tolN").grid(row=1, column=0) tolNentry = Entry(root) #Entry为允许单行文本输入的文本框,设置文本框,再定位置第1行第1列,再插入数值 tolNentry.grid(row=1, column=1) tolNentry.insert(0,"10") Label(root, text="tolS").grid(row=2, column=0) tolSentry = Entry(root) tolSentry.grid(row=2, column=1) tolSentry.insert(0,"1.0") Button(root, text = "ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)#Botton按钮,设置第1行第2列,列值为3 chkBtnVar = IntVar() #IntVar为按钮整数值小部件 chkBtn = Checkbutton(root, text = "Model Tree", variable = chkBtnVar) chkBtn.grid(row=3, column=0, columnspan = 2) reDraw.rawDat = mat(regTrees.loadDataSet("sine.txt")) reDraw.testDat = arange(min(reDraw.rawDat[:,0]), max(reDraw.rawDat[:,0]),0.01) reDraw(1.0,10) root.mainloop()
效果:
集成Matplotlib和Tkinter
matplotlib绘制的图像可以放到GUI上。matplotlib构建程序时包含一个前端,如plot、scatter函数,也同时创建了一个后端,用于实现绘图和不同应用之间的接口,改变后端可以将图像绘制在PNG,PDF,SVG等格式的文件上。matplotlib将后端设置为TkAgg,TkAgg可以在所选GUI框架上调用Agg,把Agg呈现在画布上。
coding:
import matplotlib matplotlib.use("TkAgg") #设定后端为TkAgg from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg from matplotlib.figure import Figure def reDraw(tolS,tolN): reDraw.f.clf() #清空之前的图像 reDraw.a = reDraw.f.add_subplot(111) #重新添加子图 if chkBtnVar.get(): #检查复选框是否选中,确定是模型树还是回归树 if tolN<2: tolN=2 myTree = regTrees.createTree(reDraw.rawDat, regTrees.modelLeaf, regTrees.modelErr,(tolS,tolN)) yHat = regTrees.createForeCast(myTree, reDraw.testDat, regTrees.modelTreeEval) else: myTree = regTrees.createTree(reDraw.rawDat, ops=(tolS,tolN)) yHat = regTrees.createForeCast(myTree, reDraw.testDat) reDraw.a.scatter(reDraw.rawDat[:,0], reDraw.rawDat[:,1],s=5) #画真实值的散点图 reDraw.a.plot(reDraw.testDat,yHat,linewidth=2.0) #画预测值的直线图 reDraw.canvas.show() def getInputs(): #获取用户输入的值,tolN期望得到整数值,tolS期望得到浮点数, try: tolN = int(tolNentry.get()) #在Entry部件调用get方法, except: tolN = 10 print "enter Integer for tolN" tolNentry.delete(0,END) tolNentry.insert(0,"10") try: tolS = float(tolSentry.get()) except: tolS = 1.0 print "enter Integer for tolS" tolSentry.delete(0,END) tolSentry.insert(0,"1.0") return tolN,tolS def drawNewTree(): #有人点击ReDraw按钮时就会调用该函数 tolN,tolS = getInputs() #得到输入框的值 reDraw(tolS,tolN) #调用reDraw函数 root = Tk() reDraw.f = Figure(figsize=(5,4),dpi=100) reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root) reDraw.canvas.show() reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)
效果:
Figure 9-8: 用treeExplore的GUI构建的回归树
Figure 9-9: 模型树。参数为tolN=1,tolS=0
9.8小结
数据集中输入数据和目标变量呈非线性关系,可使用树结构来对预测值分段,包括分段常数或分段直线。叶节点使用分段常数则为回归树,若为线性回归方程则为模型树。