CART算法-《机器学习实战》总结

import matplotlib.pyplot as plt
import numpy as np
import random
from numpy import *
from sklearn import *
from sklearn.tree import *
from operator import  *
'''
根据特征维度和特征值,分隔数据集
'''
def buildSpiltDataSet(dataSet,feature,value):
    #分隔数据集
    #
    mat0=dataSet[nonzero(dataSet[:,feature]>value)[0]]
    mat1=dataSet[nonzero(dataSet[:,feature]<=value)[0]]

    #print("拆分矩阵类型:\n",type(mat0))


    #print('mat0:\n',mat0)
    #print('mat1:\n',mat1)

    return  mat0,mat1


# 求给定数据集的线性方程
def linearSolve(dataSet):
    m,n = np.shape(dataSet)
    X = np.mat(np.ones((m,n))); # 第一行补1,线性拟合要求
    Y = np.mat(np.ones((m,1)))
    X[:,1:n] = dataSet[:,0:n-1];
    Y = dataSet[:,-1] # 数据最后一列是y
    xTx = X.T*X
    if np.linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing dur')
    ws = xTx.I * (X.T * Y) # 公式推导较难理解
    return ws,X,Y


# 求线性方程的参数
def modelLeaf(dataSet):
    ws, X, Y = linearSolve(dataSet)
    return ws


# 预测值和y的方差
def modelErr(dataSet):
    ws, X, Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(np.power(Y - yHat, 2))


'''
选择最优的切分点
返回最优切换点的维度和维度值
'''

#处理叶子节点,注意:最后一列才是目标值
def regLeaf(dataSet):
    return mean(dataSet[:,-1])

#计算数据集的总方差
def regVar(dataSet):
    return var(dataSet[:,-1])*(shape(dataSet)[0])

#返回最优切分点的维度和维度值
def chooseBestSplit(dataSet,pregLeaf=regLeaf,pregVar=regVar,ops=(1,4)):
    #误差范围
    tolS=ops[0]
    #数据集的大小范围
    tolN=ops[1]
    #print(len(dataSet[:,-1].T.tolist()[0]))
    #如果集合只有一个元素,说明是叶子节点,返回均值即可
    if len(set(dataSet[:,-1].T.tolist()[0]))==1 :

        return None,pregLeaf(dataSet),None

    m,n=shape(dataSet)

    #最优拆分值
    bestSpValue=inf
    #最优拆分维度
    bestFea=0.0
    #最优拆分维度值
    bestFeaValue=0.0
    #拆分前计算总的数据集的总方差
    S=pregVar(dataSet)
    #遍历列和行的值,进行划分,注意:这里是[0,n-2],不包含最后一列
    for j in range(n-1):

        targetColumn=dataSet[:,j]

        #targetColumnType=type(targetColumn)

        for colValue in set(targetColumn.T.tolist()[0]):

            #print('list value:\n',colValueMat.tolist())
            #colValue=colValueMat.tolist()[0][0]

            mat0,mat1=buildSpiltDataSet(dataSet, j, colValue)

            #print(shape(mat0)[0])

            #修改判定数据集的范围,如果小于用户指定的数据集的大小,就不继续分隔,
            #if shape(mat0)[0] ==0 or shape(mat1)[0] ==0 :
            #   continue

            if shape(mat0)[0] <tolN or shape(mat1)[0] <tolN :
                continue
            #两个划分区域的总方差之和,注意计算的是最后一列的
            currentVar = regVar(mat0) + regVar(mat1)

            if currentVar < bestSpValue:
                bestSpValue = currentVar
                bestFea = j
                bestFeaValue = colValue

   # print('最优总方差为:\n',bestSpValue)
   # print('最优划分维度为:\n',bestFea)
    #print('最优划分维度值为:\n', bestFeaValue)

    #如果误差减少并不大,则返回数据集的平均值,不继续拆分,因为数据集已经非常有序了
    if (S-bestSpValue)<tolS :
        return None,pregLeaf(dataSet),None

    return bestFea,bestFeaValue,bestSpValue

#创建CARTdef createCARTTree(dataSet,pregLeaf=regLeaf,pregVar=regVar,ops=(1,4)):

    #选择最优分类维度
    bestFea, bestFeaValue, bestSpValue=chooseBestSplit(dataSet,pregLeaf,pregVar,ops)

    # 如果是叶子节点,直接返回
    if None == bestFea :

        return bestFeaValue

    #使用字典记录相关切分数据
    cartDict={}
    #分隔维度
    cartDict['spInd']=bestFea
    #分隔维度值
    cartDict['spVal']=bestFeaValue

    #使用上面的评估,进行分隔

    leftMat,rigthMat=  buildSpiltDataSet(dataSet,bestFea,bestFeaValue)

    #分别创建左子树 右子树
    cartDict['left']=createCARTTree(leftMat,pregLeaf,pregVar,ops)
    cartDict['right']=createCARTTree(rigthMat,pregLeaf,pregVar,ops)

    return  cartDict




def loadDataSet():
    dataSet=[
        [1.5,5.56],
        [2.5,5.7],
        [3.5, 5.91],
        [4.5, 6.4],
        [5.5,6.8],
        [6.5, 7.05],
        [7.5, 8.9],
        [8.5, 8.7],
        [9.5, 9],
        [10.5, 9.05]
         ]
    return dataSet

def loadDataSet2(fileName):      #general function to parse tab -delimited floats
    dataMat = []                #assume last column is target value
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        #fltLine = map(float,curLine) #map all elements to float()
        #print('type=\n',type(curLine))
        dataMat.append(list(map(lambda x:float(x),curLine)))
    return dataMat


#树的剪枝
#判定左子节点和右子节点,是否需要合并;标准是合并后是否小于合并前的误差值

#判定一个节点是否为树节点
def isTree(dataSet):

    if  type(dataSet).__name__ == 'dict':

        return True
    else:
        return False

#获取整棵树的平均值

def getMean(tree):
    #如果左子树是树节点,递归获取左子树的平均值
    if isTree(tree['left']):
       tree['left']=getMean(tree['left'])
    #递归获取右子树的平均值
    if isTree(tree['right']):
       tree['right']=getMean(tree['right'])

    return (tree['left']+tree['right'])/2.0



#传入最优分隔点和测试数据集,使用测试数据集优化当前的最优分隔点,方法是对比合并叶子节点前后误差的变化
def prune(tree,testDataSet):

    #如果传入的测试数据集已经被划分空了,则返回树的权值????注意这里不是整颗树,而是当前符合一定范围的树,来自于划分的大小,也就是创建tree的时候那时候的划分大小
    if shape(testDataSet)[0] == 0 :
        return getMean(tree)

    #如果左子树或者右子树是树节点,说明需要继续划分,很简单,因为当初tree就是这样生成的
    #使用生成tree的划分,划分当前集合
    if isTree(tree['left']) or isTree(tree['right']) :

        lTree,rTree=buildSpiltDataSet(testDataSet,tree['spInd'],tree['spVal'])

    #如果左子树非叶子节点,说明需要递归裁剪,对左子树递归裁剪
    if isTree(tree['left']) :

        tree['left']=prune(tree['left'],lTree)

    #递归对右子树进行裁剪
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rTree)

    #如果左子树和右子树都是叶子节点,查看是否可以合并
    if not isTree(tree['left']) and not isTree(tree['right']):

         #拆分测试数据集
         lTree, rTree = buildSpiltDataSet(testDataSet, tree['spInd'], tree['spVal'])

         #计算没有合并的总方差

         errorNoMerge=sum(power(lTree[:,-1]-tree['left'],2))+sum(power(rTree[:,-1]-tree['right'],2))

         #计算合并的总方差

         mergeTreeMean=(tree['left']+tree['right'])/2.0

         errorMerge=sum(power(testDataSet[:,-1]-mergeTreeMean,2))

         #如果合并的误差更新,那就合并,返回合并后的平均值
         if errorMerge<errorNoMerge:

             return mergeTreeMean

         #否则,不合并,直接返回原来的tree

             return tree

    else:
         return tree


#画散点图
def plot(dataSet):

    x1=dataSet[:,0]
    x2=dataSet[:,1]
    fig=plt.figure('散点图')
    ax=fig.add_subplot(111)

    ax.scatter(list(x1),list(x2),s=3,c='red',marker='s')

    model1 = DecisionTreeRegressor(max_depth=3)



    model1.fit(dataSet[:, 0], dataSet[:, 1])

    minV=dataSet[:, 1].min()
    maxV=dataSet[:,1].max()

    x_test = arange(minV, maxV, 0.01).reshape(-1, 1)

    y = model1.predict(x_test)

    #plt.plot(x_test,y,color='green',label='tree regression',linewidth=2)

    plt.show()


#排序
def sort(dataSet):


    return sorted(dataSet,key=itemgetter(0))















fileName='D:\software\python\sourcecode_and_data\MLiA_SourceCode\machinelearninginaction\Ch09\exp2.txt'
#fileName='D:\software\python\sourcecode_and_data\MLiA_SourceCode\machinelearninginaction\Ch09\ex00.txt'

#dataSet=loadDataSet()
dataSet=loadDataSet2(fileName)

#dataSet=sort(dataSet)



cartDict=createCARTTree(mat(dataSet), regLeaf, regVar, (0.3, 1))


#print('数据集:\n',mat(dataSet))
print("创建的回归树为:\n",cartDict)


#创建模型树


modelDict=createCARTTree(mat(dataSet), modelLeaf, modelErr, (0.3, 1))


#print('数据集:\n',mat(dataSet))
print("创建的模型树为:\n",modelDict)



#treeMean=getMean(cartDict)

#print('整棵树的平均值:\n',treeMean)

#testfileName='D:\software\python\sourcecode_and_data\MLiA_SourceCode\machinelearninginaction\Ch09\ex2test.txt'

#测试数据的目的在于,修正训练数据产生的模型
#testDataSet=loadDataSet2(testfileName)
#pruneCartDict=prune(cartDict,testDataSet)

#print("使用测试数据集修正的CART树为:\n",pruneCartDict)


plot(mat(dataSet))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值