#-*-coding:utf-8-*-
from numpy import *
# 9.1 CART算法的实现代码
# createTree()树构建函数(数据集,其他三个可选参数:建立叶结点的函数、误差计算函数、包含树构建所需其他参数的元组)
# 是一个递归函数
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = map(float, curLine)
dataMat.append(fltLine)
return dataMat
# 参数:数据集合,待切分的特征,和该特征的某个值
# 在给定特征和特征值的情况下,该函数通过数组过滤方式将上述数据集合切分得到两个子集并返回
def binSplitDataSet(dataSet,feature,value):
mat0 = dataSet[nonzero(dataSet[:, feature]>value)[0],:]
mat1 = dataSet[nonzero(dataSet[:, feature]<=value)[0],:]
#mat0 = dataSet[nonzero(dataSet[:,feature]>value)[0],:][0]
#mat1 = dataSet[nonzero(dataSet[:,feature]<=value)[0],:][0]
return mat0, mat1
# 9.2 回归树的切分函数
# regLeaf()函数,负责生成叶节点,
# 当chooseBestSplit()切分函数确定不再对数据进行切分时,调用该函数来得到叶节点的模型。
# 在回归树中,该模型其实就是目标变量的均值
def regLeaf(dataSet):
return mean(dataSet[:,-1])
# regErr()函数,在给定数据上计算目标变量的平方误差。
# 直接调用var()函数(当然也可以先计算出均值,然后计算每个差值在平方)
def regErr(dataSet):
return var(dataSet[:,-1])*shape(dataSet)[0]
# 首先由chooseBestSplit()切分函数将数据分成两部分,chooseBestSplit()切分函数将返回none值和某类模型的值
# 如果找不到一个“好”的二元切分,该函数返回none的同时调用createTree()树构建函数来产生叶节点,叶节点的值也将返回none
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
# 一开始为ops设定了tolS、tolN两个值,用户指定参数,用于控制函数的停止时机
tolS = ops[0] # 容许的误差下降值
tolN = ops[1] # 切分的最少样本数
# chooseBestSplit()切分函数会统计不同剩余特征值的数目,如果为1,那么就不需要再切分而是直返回
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet)
# 然后函数计算当前数据集的大小和误差
m,n = shape(dataSet)
# 该误差S将用于与新切分误差进行对比,来检查新切分能否降低误差
S = errType(dataSet)
bestS = inf
bestIndex = 0
bestValue = 0
# 在所有可能的特征及
for featIndex in range(n-1):
# 其可能取值上遍历
# for splitVal in set(dataSet[:,featIndex]):
for splitVal in set(dataSet[:, featIndex].T.A.tolist()[0]):
mat0 , mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
newS = errType(mat0) + errType(mat1)
if newS<bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
# 接下来将会看到chooseBestSplit()切分函数中有三种情况不会切分,而是直接创造叶节点。
# 如果切分数据集后效果提升不够大,那么就不应该进行切分操作而直接创建叶节点
if (S - bestS)<tolS:
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet,bestIndex,bestValue)
# 另外还需检查两个切分后的子集大小,如果某个子集的大小小于用户定义参tolN,那么也不应切分
if (shape(mat0)[0]<tolN) or (shape(mat1)[0]<tolN):
return None, leafType(dataSet)
# 如果找到一个“好”的切分方式,则返回特征编号和切分特征值
return bestIndex, bestValue
# createTree()树构建函数(数据集,其他三个可选参数:建立叶结点的函数、误差计算函数、包含树构建所需其他参数的元组)
# 是一个递归函数
def createTree(dataSet, leafType = regLeaf, errType = regErr, ops=(1,4)):
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
# 如果构建的是回归树,该模型是个常数;如果是模型树,该模型是个线性方程
if feat == None:
return val
retTree = { }
retTree['spInd'] = feat
retTree['spVal'] = val
lSet, rSet = binSplitDataSet(dataSet,feat,val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType,errType,ops)
return retTree
# 树剪枝
# 9.3 回归树剪枝函数
# isTree()用于测试输入变量是否是一棵树,返回布尔类型的结果
def isTree(obj):
return (type(obj).__name__ == 'dict')
# getMean()是一个递归函数,从上往下遍历树直到叶节点为止
# 如果找到两个叶节点则计算它们的平均值
# 该函数对树进行塌陷处理(即返回树平均值),在prune()函数中调用该函数时应明确这一点
def getMean(tree):
if isTree(tree['right']):
tree['right'] = getMean(tree['right'])
if isTree(tree['left']):
tree['left'] = getMean(tree['left'])
return (tree['left']+tree['right'])/2.0
# 主函数是prune()函数,参数:待剪枝的树与剪枝所需的测试数据testData
def prune(tree, testData):
# 首先需确认测试集是否为空
if shape(testData)[0] == 0:
return getMean(tree)
# 非空则反复递归调用函数prune()对数据进行切分
if (isTree(tree['right']) or isTree(tree['left'])) :
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']):
tree['left'] = prune(tree['left'],lSet)
if isTree(tree['right']):
tree['right'] = prune(tree['right'], rSet)
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) + sum(power(rSet[:,-1] - tree['right'],2))
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = sum(power(testData[:,1] - treeMean,2))
if errorMerge < errorNoMerge:
print("merging")
return treeMean
else: return tree
else: return tree
# 9.4 模型树的叶节点生成函数
# linearSolve()函数,会被其他两个函数调用
# 其主要功能是将数据集格式化成目标变量Y和自变量X
def linearSolve(dataSet):
m,n = shape(dataSet)
X = mat(ones((m,n)))
Y = mat(ones((m,1)))
X[:,1:n] = dataSet[:,0:n-1]
Y = dataSet[:,-1]
xTx = X.T*X
if linalg.det(xTx) == 0.0:
raise NameError("this matrix is singular, cannot do inverse,try increasing the second value of ops")
ws = xTx.I * (X.T * Y)
return ws, X, Y
# modelLesf()函数,当数据不在需要切分时它负责乘胜叶节点的模型
def modelLesf(dataSet):
# 该函数在数据集上调用linearSolve()函数,并返回回归系数ws
ws, X, Y = linearSolve(dataSet)
return ws
# modelErr()函数,在给定的数据集上计算误差,会被chooseBestSplit()函数调用来找到最佳的切分
def modelErr(dataSet):
# 该函数在数据集上调用linearSolve()函数,之后返回yhat和Y之间的平方误差
ws, X, Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y -yHat,2))
# 9.5 用树回归进行预测的代码
# 要对回归树叶节点进行预测,就调用regTreeEval()函数,对输入数据进行格式化处理
def regTreeEval(model, inDat):
return float(model)
# 要对模型树叶节点进行预测,就调用modelTreeEval()函数,对输入数据进行格式化处理
def modelTreeEval(model, inDat):
n = shape(inDat)[1]
# 在原数据矩阵上增加第0列,然后计算并返回预测值
X = mat(ones(1,n+1))
X[:,1:n+1] = inDat
return float(X*model)
# 对于输入的单个数据点或者行向量,函数treeForeCast()会返回一个浮点值。
# 在给定树结构的情况下,对于单个数据点,该函数会给出一个预测值。
# 调用函数treeForeCast()时需要指定树的类型,以便在叶节点上能够调用合适的模型
def treeForeCast(tree, inData, modelEval = regTreeEval):
if not isTree(tree):
return modelEval(tree, inData)
if inData(tree['spInd']) > tree['spVal']:
if isTree(tree['left']):
return treeForeCast(tree['left'],inData,modelEval)
else:
return modelEval(tree['left'], inData)
else:
if isTree(tree['right']):
return treeForeCast(tree['right'], inData,modelEval)
else:
return modelEval(tree['right'], inData)
def createForeCast(tree, testData, modelEval = regTreeEval):
m = len(testData)
yHat = mat(zeros((m,1)))
for i in range(m):
yHat[i,0 ] = treeForeCast(tree, mat(testData[i]), modelEval)
return yHat
# 9.6 用于构建树管理器界面的Thinter小部件
from numpy import *
from Tkinter import *
import regTree
def reDraw(tolS,tolN):
pass
def drawNewTree():
pass
root = Tk()
Label(root, text = "Plot Place Holder").grid(row = 0, columnspan = 2)
Label(root ,text = "tolN").grid(row=1, column=0)
tolNentry = Entry(root)
tolNentry.grid(row=1, column=1)
tolNentry.insert(0,'10')
Label(root, text = "tolS").grid(row =2,column = 0)
tolSentry = Entry(root)
tolSentry.grid(row=2, column=1)
tolSentry.insert(0,'1.0')
Button(root, text="ReDraw", command = drawNewTree).grid(row=1,column=2, rowspan = 3)
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text="model tree", varible = chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)
reDraw.rawDat = mat(regTree.loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:,0],max(reDraw.rawDat[:,0],0.01)))
reDraw(1.0,10)
root.mainloop()
MLiA笔记_树回归
最新推荐文章于 2021-05-22 16:51:10 发布