import matplotlib.pyplot as plt import numpy as np import random from numpy import * from sklearn import * from sklearn.tree import * from operator import * ''' 根据特征维度和特征值,分隔数据集 ''' def buildSpiltDataSet(dataSet,feature,value): #分隔数据集 # mat0=dataSet[nonzero(dataSet[:,feature]>value)[0]] mat1=dataSet[nonzero(dataSet[:,feature]<=value)[0]] #print("拆分矩阵类型:\n",type(mat0)) #print('mat0:\n',mat0) #print('mat1:\n',mat1) return mat0,mat1 # 求给定数据集的线性方程 def linearSolve(dataSet): m,n = np.shape(dataSet) X = np.mat(np.ones((m,n))); # 第一行补1,线性拟合要求 Y = np.mat(np.ones((m,1))) X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1] # 数据最后一列是y xTx = X.T*X if np.linalg.det(xTx) == 0.0: raise NameError('This matrix is singular, cannot do inverse,\n\ try increasing dur') ws = xTx.I * (X.T * Y) # 公式推导较难理解 return ws,X,Y # 求线性方程的参数 def modelLeaf(dataSet): ws, X, Y = linearSolve(dataSet) return ws # 预测值和y的方差 def modelErr(dataSet): ws, X, Y = linearSolve(dataSet) yHat = X * ws return sum(np.power(Y - yHat, 2)) ''' 选择最优的切分点 返回最优切换点的维度和维度值 ''' #处理叶子节点,注意:最后一列才是目标值 def regLeaf(dataSet): return mean(dataSet[:,-1]) #计算数据集的总方差 def regVar(dataSet): return var(dataSet[:,-1])*(shape(dataSet)[0]) #返回最优切分点的维度和维度值 def chooseBestSplit(dataSet,pregLeaf=regLeaf,pregVar=regVar,ops=(1,4)): #误差范围 tolS=ops[0] #数据集的大小范围 tolN=ops[1] #print(len(dataSet[:,-1].T.tolist()[0])) #如果集合只有一个元素,说明是叶子节点,返回均值即可 if len(set(dataSet[:,-1].T.tolist()[0]))==1 : return None,pregLeaf(dataSet),None m,n=shape(dataSet) #最优拆分值 bestSpValue=inf #最优拆分维度 bestFea=0.0 #最优拆分维度值 bestFeaValue=0.0 #拆分前计算总的数据集的总方差 S=pregVar(dataSet) #遍历列和行的值,进行划分,注意:这里是[0,n-2],不包含最后一列 for j in range(n-1): targetColumn=dataSet[:,j] #targetColumnType=type(targetColumn) for colValue in set(targetColumn.T.tolist()[0]): #print('list value:\n',colValueMat.tolist()) #colValue=colValueMat.tolist()[0][0] mat0,mat1=buildSpiltDataSet(dataSet, j, colValue) #print(shape(mat0)[0]) #修改判定数据集的范围,如果小于用户指定的数据集的大小,就不继续分隔, #if shape(mat0)[0] ==0 or shape(mat1)[0] ==0 : # continue if shape(mat0)[0] <tolN or shape(mat1)[0] <tolN : continue #两个划分区域的总方差之和,注意计算的是最后一列的 currentVar = regVar(mat0) + regVar(mat1) if currentVar < bestSpValue: bestSpValue = currentVar bestFea = j bestFeaValue = colValue # print('最优总方差为:\n',bestSpValue) # print('最优划分维度为:\n',bestFea) #print('最优划分维度值为:\n', bestFeaValue) #如果误差减少并不大,则返回数据集的平均值,不继续拆分,因为数据集已经非常有序了 if (S-bestSpValue)<tolS : return None,pregLeaf(dataSet),None return bestFea,bestFeaValue,bestSpValue #创建CART树 def createCARTTree(dataSet,pregLeaf=regLeaf,pregVar=regVar,ops=(1,4)): #选择最优分类维度 bestFea, bestFeaValue, bestSpValue=chooseBestSplit(dataSet,pregLeaf,pregVar,ops) # 如果是叶子节点,直接返回 if None == bestFea : return bestFeaValue #使用字典记录相关切分数据 cartDict={} #分隔维度 cartDict['spInd']=bestFea #分隔维度值 cartDict['spVal']=bestFeaValue #使用上面的评估,进行分隔 leftMat,rigthMat= buildSpiltDataSet(dataSet,bestFea,bestFeaValue) #分别创建左子树 右子树 cartDict['left']=createCARTTree(leftMat,pregLeaf,pregVar,ops) cartDict['right']=createCARTTree(rigthMat,pregLeaf,pregVar,ops) return cartDict def loadDataSet(): dataSet=[ [1.5,5.56], [2.5,5.7], [3.5, 5.91], [4.5, 6.4], [5.5,6.8], [6.5, 7.05], [7.5, 8.9], [8.5, 8.7], [9.5, 9], [10.5, 9.05] ] return dataSet def loadDataSet2(fileName): #general function to parse tab -delimited floats dataMat = [] #assume last column is target value fr = open(fileName) for line in fr.readlines(): curLine = line.strip().split('\t') #fltLine = map(float,curLine) #map all elements to float() #print('type=\n',type(curLine)) dataMat.append(list(map(lambda x:float(x),curLine))) return dataMat #树的剪枝 #判定左子节点和右子节点,是否需要合并;标准是合并后是否小于合并前的误差值 #判定一个节点是否为树节点 def isTree(dataSet): if type(dataSet).__name__ == 'dict': return True else: return False #获取整棵树的平均值 def getMean(tree): #如果左子树是树节点,递归获取左子树的平均值 if isTree(tree['left']): tree['left']=getMean(tree['left']) #递归获取右子树的平均值 if isTree(tree['right']): tree['right']=getMean(tree['right']) return (tree['left']+tree['right'])/2.0 #传入最优分隔点和测试数据集,使用测试数据集优化当前的最优分隔点,方法是对比合并叶子节点前后误差的变化 def prune(tree,testDataSet): #如果传入的测试数据集已经被划分空了,则返回树的权值????注意这里不是整颗树,而是当前符合一定范围的树,来自于划分的大小,也就是创建tree的时候那时候的划分大小 if shape(testDataSet)[0] == 0 : return getMean(tree) #如果左子树或者右子树是树节点,说明需要继续划分,很简单,因为当初tree就是这样生成的 #使用生成tree的划分,划分当前集合 if isTree(tree['left']) or isTree(tree['right']) : lTree,rTree=buildSpiltDataSet(testDataSet,tree['spInd'],tree['spVal']) #如果左子树非叶子节点,说明需要递归裁剪,对左子树递归裁剪 if isTree(tree['left']) : tree['left']=prune(tree['left'],lTree) #递归对右子树进行裁剪 if isTree(tree['right']): tree['right'] = prune(tree['right'], rTree) #如果左子树和右子树都是叶子节点,查看是否可以合并 if not isTree(tree['left']) and not isTree(tree['right']): #拆分测试数据集 lTree, rTree = buildSpiltDataSet(testDataSet, tree['spInd'], tree['spVal']) #计算没有合并的总方差 errorNoMerge=sum(power(lTree[:,-1]-tree['left'],2))+sum(power(rTree[:,-1]-tree['right'],2)) #计算合并的总方差 mergeTreeMean=(tree['left']+tree['right'])/2.0 errorMerge=sum(power(testDataSet[:,-1]-mergeTreeMean,2)) #如果合并的误差更新,那就合并,返回合并后的平均值 if errorMerge<errorNoMerge: return mergeTreeMean #否则,不合并,直接返回原来的tree return tree else: return tree #画散点图 def plot(dataSet): x1=dataSet[:,0] x2=dataSet[:,1] fig=plt.figure('散点图') ax=fig.add_subplot(111) ax.scatter(list(x1),list(x2),s=3,c='red',marker='s') model1 = DecisionTreeRegressor(max_depth=3) model1.fit(dataSet[:, 0], dataSet[:, 1]) minV=dataSet[:, 1].min() maxV=dataSet[:,1].max() x_test = arange(minV, maxV, 0.01).reshape(-1, 1) y = model1.predict(x_test) #plt.plot(x_test,y,color='green',label='tree regression',linewidth=2) plt.show() #排序 def sort(dataSet): return sorted(dataSet,key=itemgetter(0)) fileName='D:\software\python\sourcecode_and_data\MLiA_SourceCode\machinelearninginaction\Ch09\exp2.txt' #fileName='D:\software\python\sourcecode_and_data\MLiA_SourceCode\machinelearninginaction\Ch09\ex00.txt' #dataSet=loadDataSet() dataSet=loadDataSet2(fileName) #dataSet=sort(dataSet) cartDict=createCARTTree(mat(dataSet), regLeaf, regVar, (0.3, 1)) #print('数据集:\n',mat(dataSet)) print("创建的回归树为:\n",cartDict) #创建模型树 modelDict=createCARTTree(mat(dataSet), modelLeaf, modelErr, (0.3, 1)) #print('数据集:\n',mat(dataSet)) print("创建的模型树为:\n",modelDict) #treeMean=getMean(cartDict) #print('整棵树的平均值:\n',treeMean) #testfileName='D:\software\python\sourcecode_and_data\MLiA_SourceCode\machinelearninginaction\Ch09\ex2test.txt' #测试数据的目的在于,修正训练数据产生的模型 #testDataSet=loadDataSet2(testfileName) #pruneCartDict=prune(cartDict,testDataSet) #print("使用测试数据集修正的CART树为:\n",pruneCartDict) plot(mat(dataSet))
CART算法-《机器学习实战》总结
最新推荐文章于 2024-05-10 17:11:38 发布