机器学习实战学习笔记5-树回归

CART(Classification And Regression Trees 分类回归树)

https://blog.csdn.net/gzj_1101/article/details/78355234
https://blog.csdn.net/e15273/article/details/80463079
https://blog.csdn.net/xgxyxs/article/details/79436235 以上帖子中有实例

树回归的一般方法

(1)收集数据:可以使用任何方法收集数据
(2)准备数据:需要数值型数据,标称型数据将被转换成二值型数据
(3)分析数据: 给出数据的二维可视化显示结果,以字典方式生成树
(4)训练算法:大部分时间都花费在叶节点树模型的构建上
(5)测试算法:使用测试数据上的 R 2 R^2 R2 值来分析模型的效果
(6)使用算法:使用训练出的树做预测,预测结果还可以用来做很多事情
CART算法:

CART算法正好适用于连续型特征。CART算法使用二元切分法来处理连续型变量。而使用二元切分法则易于对树构建过程进行调整以处理连续型特征。具体的处理方法是:如果特征值大于给定值就走左子树,否则就走右子树。

CART算法有两步:
决策树生成:递归地构建二叉决策树的过程,基于训练数据集生成决策树,生成的决策树要尽量大;自上而下从根开始建立节点,在每个节点处要选择一个最好的属性来分裂,使得子节点中的训练集尽量的纯。不同的算法使用不同的指标来定义”最好”:
决策树剪枝:用验证数据集对已生成的树进行剪枝并选择最优子树,这时损失函数最小作为剪枝的标准。
CART算法的决策树生成实现过程如下:
(1)使用CART算法选择特征
(2)根据特征切分数据集合
(3)构建树

伪代码

找到最佳的待切分特征:
如果该节点不能再分,将该节点存为叶节点
执行二元切分
在右子树调用createTree() 方法
在左子树调用createTree() 方法 
程序清单9-1 标准回归函数和数据导入函数
from numpy import *
import numpy as np

import matplotlib.pyplot as plt


def loadDataSet(fileName):
    """
    读取tab键分隔符的文件将每行的内容保存成一组浮点数
    """
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float,curLine))
        '''
        TypeError: unsupported operand type(s) for /: ‘map‘ and ‘int‘ 
        修改loadDataSet函数某行为fltLine = list(map(float,curLine)),因为python3中map的返回值变了,所以要加list() 
        '''
        dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet,feature,value):
    '''
    参数:数据集合,待切分的特征和该特征的某个值
    在给定特征和特征值,函数通过数组过滤方式将数据切分得到两个子集
    '''
    # np.nonzero(a),返回数组a中非零元素的索引值数组 
    # np.nonzero(dataSet[:, feature] > value)[0]  = 1,行下标为1 第二行 
    # 下面一行代码表示mat0=dataSet[1,:]即第二行所有列
    #print(dataSet[:,feature])
    #print(nonzero(dataSet[:, feature] > value)[0])
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
    # np.nonzero(dataSet[:, feature] <= value)[0],表示取第一列中小于0.5的数的索引值,
    # 下面代码表示mat1=dataSet[1,:]即第一、三、四行所有列
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]    
    return mat0,mat1

testMat = mat(eye(4))
testMat
matrix([[ 1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  1.]])
mat0,mat1 = binSplitDataSet(testMat, 1, 0.5)
print(mat0)
[[ 0.  1.  0.  0.]]
print(mat1)
[[ 1.  0.  0.  0.]
 [ 0.  0.  1.  0.]
 [ 0.  0.  0.  1.]]

def plotDataSet(filename):

    '''
    函数说明:绘制数据集
    Parameters:
        filename - 文件名
    Returns:
        无
    '''
    dataMat = loadDataSet(filename)                                        #加载数据集
    n = len(dataMat)                                                    #数据个数
    xcord = []; ycord = []                                                #样本点
    for i in range(n):
        xcord.append(dataMat[i][0]); ycord.append(dataMat[i][1])        #样本点
    fig = plt.figure()
    ax = fig.add_subplot(111)                                            #添加subplot
    ax.scatter(xcord, ycord, s = 20, c = 'blue',alpha = .5)                #绘制样本点
    plt.title('DataSet')                                                #绘制title
    plt.xlabel('X')
    plt.show()

filename = 'ex00.txt'
plotDataSet(filename)

在这里插入图片描述

找到数据集切分的最佳位置的伪代码

对每个特征:
   对每个特征值:
      将数据集切分成两份:
      计算切分的误差
      如果当前误差小于当前最小误差:那么将当前切分设定为最佳切分并更新最小误差,返回最佳切分的特征和阈值
程序清单9-2 标准回归函数和数据导入函数
# 负责生成叶节点,当chooseBestSplit()函数确定不再对数据进行切分时,
def regLeaf(dataSet):
    return mean(dataSet[:,-1]) # 将调用该regLeaf()函数来得到叶节点的模型,在回归树中,该模型其实就是目标变量的均值

# 误差估计函数,该函数在给定的数据上计算目标变量的平方误差,
def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0] #这里直接调用均方差函数var,因为这里需要返回的是总方差,所以要用均方差乘以数据集中样本的个数

##回归树切分函数,构建回归树的核心函数。目的:找出数据的最佳二元切分方式。
#如果找不到一个“好”的二元切分,该函数返回None并同时调用createTree()方法来产生叶节点,叶节点的值也将返回None。
# 如果找到一个“好”的切分方式,则返回特征编号和切分特征值。
# 最佳切分就是使得切分后能达到最低误差的切分。
def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr, ops = (1,4)):
    """
    构建回归树的核心函数,找到最佳的二元切分方式   
    leafType 是对创建叶节点的函数的引用
    errType 是对总方差的计算函数的引用
    ops 是一个用户自定义的参数构成的元组,用以完成树的构建
    """
    tolS = ops[0]; tolN = ops[1] 
    #tolS 容许的误差下降值
    #tolN 切分的最少样本数
    # 如果剩余特征值的数目为1,那么就不再切分而返回
    #如果当前所有值相等,则退出。(根据set的特性)
    if len(set(dataSet[:,-1].T.tolist()[0])) ==1:
        return None,leafType(dataSet)   #如果所有值相等则退出
    # 当前数据集的大小
    m, n = np.shape(dataSet)
    # 当前数据集的误差
    # 计算数据集最后一列的特征总方差。
    #默认最后一个特征为最佳切分特征,计算其误差估计
    S = errType(dataSet)
    #float("inf") 表示正负无穷
    #分别为最佳误差,最佳特征切分的索引值,最佳特征值
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        #for splitVal in set(dataSet[:,featIndex]):
        for splitVal in set((dataSet[:,featIndex].T.A.tolist())[0]):
            '''
            TypeError: unhashable type: ‘matrix’ 
            修改chooseBestSplit函数某行为:for splitVal in set((dataSet[:,featIndex].T.A.tolist())[0]): matrix类型不能被hash。
            '''
            #根据特征和特征值切分数据集
            mat0,mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
             #如果数据少于tolN,则退出
            if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0] < tolN):
                continue
             #计算误差估计
            newS = errType(mat0) + errType(mat1)
            #如果误差估计更小,则更新特征索引值和特征值
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # 如果切分数据集后效果提升不够大,那么就不应该进行切分操作而直接创建叶节点             
    if (S - bestS) < tolS:
        return None, leafType(dataSet)  #如果误差不大则退出
    mat0,mat1 = binSplitDataSet(dataSet,bestIndex,bestValue)
    # 检查切分后的子集大小,如果某个子集的大小小于用户定义的参数tolN,那么也不应切分。
    if (np.shape(mat0)[0] < tolN) or (np.shape(mat1)[0]) < tolN:
        return None, leafType(dataSet)  #如果切分的数据集很小则退出
    # 如果前面的这些终止条件都不满足,那么就返回切分特征和特征值。
    return bestIndex, bestValue

# dataSet: 数据集合
# leafType: 给出建立叶节点的函数
# errType: 误差计算函数
# ops: 包含树构建所需其他参数的元组
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    # 将数据集分成两个部分,若满足停止条件,chooseBestSplit将返回None和某类模型的值
    # 若构建的是回归树,该模型是个常数。若是模型树,其模型是一个线性方程。
    # 若不满足停止条件,chooseBestSplit()将创建一个新的Python字典,并将数据集分成两份,
    # 在这两份数据集上将分别继续递归调用createTree()函数
    #选择最佳切分特征和特征值
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    #r如果没有特征,则返回特征值
    if feat == None: return val 
     #回归树
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
     #分成左数据集和右数据集
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    #创建左子树和右子树
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree  

myDat = loadDataSet('ex00.txt')
myMat = mat(myDat)
createTree(myMat)
{'left': 1.0180967672413792,
 'right': -0.044650285714285719,
 'spInd': 0,
 'spVal': 0.48813}

def plotDataSet1(filename):

    '''
    函数说明:绘制数据集
    Parameters:
        filename - 文件名
    Returns:
        无
    '''
    dataMat = loadDataSet(filename)                                        #加载数据集
    n = len(dataMat)                                                    #数据个数
    xcord = []; ycord = []                                                #样本点
    for i in range(n):
        xcord.append(dataMat[i][1]); ycord.append(dataMat[i][2])        #样本点
    fig = plt.figure()
    ax = fig.add_subplot(111)                                            #添加subplot
    ax.scatter(xcord, ycord, s = 20, c = 'blue',alpha = .5)                #绘制样本点
    plt.title('DataSet')                                                #绘制title
    plt.xlabel('X')
    plt.show()
    
filename = 'ex0.txt'
plotDataSet1(filename)

在这里插入图片描述

myDat = loadDataSet('ex0.txt')
myMat1 = np.mat(myDat)
createTree(myMat1)
{'left': {'left': {'left': 3.9871631999999999,
   'right': 2.9836209534883724,
   'spInd': 1,
   'spVal': 0.797583},
  'right': 1.980035071428571,
  'spInd': 1,
  'spVal': 0.582002},
 'right': {'left': 1.0289583666666666,
  'right': -0.023838155555555553,
  'spInd': 1,
  'spVal': 0.197834},
 'spInd': 1,
 'spVal': 0.39435}
feat, val = chooseBestSplit(myMat, regLeaf, regErr, (1, 4))
print(feat)
print(val)
0
0.48813
#https://blog.csdn.net/jiaoyangwm/article/details/79631480  详见该贴  呆呆的猫

#https://blog.csdn.net/sinat_17196995/article/details/69621687   修行的猫 错误调试讲解清楚
#createTree(myMat,ops=(0,1))
myDat2 = loadDataSet('ex2.txt')
myMat2 = mat(myDat2)
createTree(myMat2,ops=(10000,4))
{'left': 101.35815937735848,
 'right': -2.6377193297872341,
 'spInd': 0,
 'spVal': 0.499171}
剪枝技术

一棵树如果节点过多,表示该模型可能对数据进行了“过拟合”。通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning)。

剪枝分为预剪枝(prepruning)和后剪枝(postpruning)。

预剪枝是
 指在决策树生成过程中,对每个节点在划分前先进行估计,若当前节点的划分不能带来决策树泛化性能的提升,则停止划分并将当前节点记为叶节点(上面的程序已经使用了预剪枝);
后剪枝
是先在训练集生成一棵完整的决策树,然后自底向上地对非叶节点进行考察,若将该节点对应的子树替换为叶节点能带来决策树泛化性能提升,则将该子树替换为叶节点
后减枝prune()函数的伪代码:
基于已有的树切分测试数据:
      如果存在任一子集是一棵树,则在该子集递归剪枝过程
      计算将当前两个叶节点合并后的误差
      计算不合并的误差
      如果合并会降低误差的话,就将叶节点合并
程序清单9-3 回归树剪枝函数
#后剪枝操作,利用训练集创建好树之后。根据测试数据集进行后剪枝处理
#测试输入对象是否为树
def isTree(obj):
    #树结构是用python中的字典存储的,所以只用判断该对象类型是否为dict
    return type(obj).__name__ == 'dict'
#在进行树塌陷处理时,需要求得树节点的平均值
def getMean(tree):
    # 如果右子树是树的话,求得右子树平均值,并将此值置为右子树的值
    if (isTree(tree['right'])): tree['right'] = getMean(tree['right'])
    #如果左子树是树的话,求得左子树平均值,并将此值置为左子树的值
    if (isTree(tree['left'])): tree['left'] = getMean(tree['left'])
    #返回该节点的平均值
    return (tree['left'] + tree['right'])/2.0

#树剪枝函数
def prune(tree,testSet):
    #如果测试集样本数为空,则对树进行塌陷处理,返回树节点的均值,剪枝到最后测试集空的情况
    if (np.shape(testSet)[0]==0):
        return getMean(tree)
    #如果测试集不为空,如果左子树或者右子树还是树的话,则对测试集进行切分,将切分数据传入剪枝函数中对树结构进行递归剪枝
    #获取到下一次递归中的测试数据集
    if( isTree(tree['right'])or isTree(tree['left']) ):
        lSet, rSet = binSplitDataSet(testSet,tree['spInd'],tree['spVal'])
    #分别对左右子树进行剪枝
    if (isTree(tree['left'])):
        tree['left'] = prune(tree['left'],lSet)
    if (isTree(tree['right'])):
        tree['right'] = prune(tree['right'],rSet)
    #最后一次递归进行是否合并判断,进入最后一次递归的条件是传入树的左右子树不在是子树,而是叶节点
    if ((not isTree(tree['left'])) and (not isTree(tree['right']))):
        #进行二值划分
        lSet, rSet = binSplitDataSet(testSet, tree['spInd'], tree['spVal'])
        #计算没有合并的误差和
        notMergeErr = sum(np.power(lSet[:,-1] - tree['left'],2)) + sum(np.power(rSet[:,-1] - tree['right'],2))
        #计算合并的话节点值,已经到了叶节点,所以之间计算平均值
        meanNodeNum = (tree['left'] + tree['right'])/2.0
        #计算合并后的误差和
        mergeErr = sum(np.power(testSet[:,-1] - meanNodeNum,2))
        #如果合并后的误差方和减少了,则合并
        if mergeErr < notMergeErr:
            print("merging")
            return meanNodeNum
        #否则,返回该树
        else:
            return tree
    #最后返回完成剪枝之后的树
    else:
        return tree

myTree = createTree(myMat2,ops=(0,1))
myDatTree = loadDataSet('ex2test.txt')
myMat2Test = mat(myDatTree)
prune(myTree,myMat2Test)
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging
merging





{'left': {'left': {'left': {'left': 92.523991499999994,
    'right': {'left': {'left': {'left': 112.386764,
       'right': 123.559747,
       'spInd': 0,
       'spVal': 0.960398},
      'right': 135.83701300000001,
      'spInd': 0,
      'spVal': 0.958512},
     'right': 111.2013225,
     'spInd': 0,
     'spVal': 0.956951},
    'spInd': 0,
    'spVal': 0.965969},
   'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 96.41885225,
              'right': 69.318648999999994,
              'spInd': 0,
              'spVal': 0.948822},
             'right': {'left': {'left': 110.03503850000001,
               'right': {'left': 65.548417999999998,
                'right': {'left': 115.75399400000001,
                 'right': {'left': {'left': 94.396114499999996,
                   'right': 85.005351000000005,
                   'spInd': 0,
                   'spVal': 0.912161},
                  'right': {'left': {'left': 106.814667,
                    'right': 118.513475,
                    'spInd': 0,
                    'spVal': 0.908629},
                   'right': {'left': 87.300624999999997,
                    'right': {'left': {'left': 100.133819,
                      'right': 108.09493399999999,
                      'spInd': 0,
                      'spVal': 0.900699},
                     'right': {'left': 82.436685999999995,
                      'right': {'left': 98.544549499999988,
                       'right': 106.16859550000001,
                       'spInd': 0,
                       'spVal': 0.872199},
                      'spInd': 0,
                      'spVal': 0.888426},
                     'spInd': 0,
                     'spVal': 0.892999},
                    'spInd': 0,
                    'spVal': 0.901421},
                   'spInd': 0,
                   'spVal': 0.901444},
                  'spInd': 0,
                  'spVal': 0.910975},
                 'spInd': 0,
                 'spVal': 0.925782},
                'spInd': 0,
                'spVal': 0.934853},
               'spInd': 0,
               'spVal': 0.936524},
              'right': {'left': {'left': 89.20993,
                'right': 76.240983999999997,
                'spInd': 0,
                'spVal': 0.847219},
               'right': 95.893130999999997,
               'spInd': 0,
               'spVal': 0.84294},
              'spInd': 0,
              'spVal': 0.85497},
             'spInd': 0,
             'spVal': 0.944221},
            'right': 60.552307999999996,
            'spInd': 0,
            'spVal': 0.841625},
           'right': 124.87935300000001,
           'spInd': 0,
           'spVal': 0.841547},
          'right': {'left': 76.723834999999994,
           'right': {'left': 59.342323,
            'right': 70.054507999999998,
            'spInd': 0,
            'spVal': 0.819722},
           'spInd': 0,
           'spVal': 0.823848},
          'spInd': 0,
          'spVal': 0.833026},
         'right': {'left': 118.319942,
          'right': {'left': 99.841379000000003,
           'right': 112.981216,
           'spInd': 0,
           'spVal': 0.811363},
          'spInd': 0,
          'spVal': 0.811602},
         'spInd': 0,
         'spVal': 0.815215},
        'right': 73.494399250000001,
        'spInd': 0,
        'spVal': 0.806158},
       'right': {'left': 114.4008695,
        'right': 102.26514075,
        'spInd': 0,
        'spVal': 0.786865},
       'spInd': 0,
       'spVal': 0.790312},
      'right': 64.041940999999994,
      'spInd': 0,
      'spVal': 0.769043},
     'right': 115.199195,
     'spInd': 0,
     'spVal': 0.763328},
    'right': 78.085643250000004,
    'spInd': 0,
    'spVal': 0.759504},
   'spInd': 0,
   'spVal': 0.952833},
  'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 110.90282999999999,
         'right': {'left': 103.345308,
          'right': 108.55391899999999,
          'spInd': 0,
          'spVal': 0.710234},
         'spInd': 0,
         'spVal': 0.716211},
        'right': 135.41676699999999,
        'spInd': 0,
        'spVal': 0.70889},
       'right': {'left': {'left': {'left': {'left': 106.18042699999999,
           'right': 105.062147,
           'spInd': 0,
           'spVal': 0.70639},
          'right': 115.58660500000001,
          'spInd': 0,
          'spVal': 0.699873},
         'right': 92.470635999999999,
         'spInd': 0,
         'spVal': 0.69892},
        'right': {'left': 120.521925,
         'right': {'left': 101.91115275,
          'right': 112.78136649999999,
          'spInd': 0,
          'spVal': 0.666452},
         'spInd': 0,
         'spVal': 0.689099},
        'spInd': 0,
        'spVal': 0.698472},
       'spInd': 0,
       'spVal': 0.706961},
      'right': {'left': 121.98060700000001,
       'right': {'left': 115.687524,
        'right': 112.715799,
        'spInd': 0,
        'spVal': 0.652462},
       'spInd': 0,
       'spVal': 0.661073},
      'spInd': 0,
      'spVal': 0.665329},
     'right': 82.500765999999999,
     'spInd': 0,
     'spVal': 0.642707},
    'right': 140.61394100000001,
    'spInd': 0,
    'spVal': 0.642373},
   'right': {'left': {'left': {'left': {'left': 82.713621000000003,
       'right': {'left': 91.656616999999997,
        'right': 93.645292999999995,
        'spInd': 0,
        'spVal': 0.632691},
       'spInd': 0,
       'spVal': 0.637999},
      'right': {'left': 117.62834599999999,
       'right': 105.970743,
       'spInd': 0,
       'spVal': 0.624827},
      'spInd': 0,
      'spVal': 0.628061},
     'right': 82.04976400000001,
     'spInd': 0,
     'spVal': 0.623909},
    'right': {'left': 168.180746,
     'right': {'left': {'left': {'left': {'left': {'left': {'left': 93.521395999999996,
           'right': {'left': 130.37852899999999,
            'right': {'left': 111.9849935,
             'right': {'left': 82.589327999999995,
              'right': {'left': 114.872056,
               'right': 108.43539199999999,
               'spInd': 0,
               'spVal': 0.569327},
              'spInd': 0,
              'spVal': 0.571214},
             'spInd': 0,
             'spVal': 0.582311},
            'spInd': 0,
            'spVal': 0.589806},
           'spInd': 0,
           'spVal': 0.599142},
          'right': 82.903944999999993,
          'spInd': 0,
          'spVal': 0.560301},
         'right': 129.06244849999999,
         'spInd': 0,
         'spVal': 0.553797},
        'right': {'left': 83.114502000000002,
         'right': {'left': 97.340526499999996,
          'right': 90.995536000000001,
          'spInd': 0,
          'spVal': 0.537834},
         'spInd': 0,
         'spVal': 0.546601},
        'spInd': 0,
        'spVal': 0.548539},
       'right': {'left': {'left': 129.76674299999999,
         'right': 124.795495,
         'spInd': 0,
         'spVal': 0.531944},
        'right': 116.17616200000001,
        'spInd': 0,
        'spVal': 0.51915},
       'spInd': 0,
       'spVal': 0.533511},
      'right': {'left': 101.075609,
       'right': {'left': 93.292828999999998,
        'right': 96.403373000000002,
        'spInd': 0,
        'spVal': 0.508542},
       'spInd': 0,
       'spVal': 0.508548},
      'spInd': 0,
      'spVal': 0.513332},
     'spInd': 0,
     'spVal': 0.606417},
    'spInd': 0,
    'spVal': 0.613004},
   'spInd': 0,
   'spVal': 0.640515},
  'spInd': 0,
  'spVal': 0.729397},
 'right': {'left': {'left': {'left': {'left': {'left': 8.5367700000000006,
      'right': 27.729263,
      'spInd': 0,
      'spVal': 0.487381},
     'right': 5.224234,
     'spInd': 0,
     'spVal': 0.483803},
    'right': {'left': -9.7129250000000003,
     'right': -23.777531,
     'spInd': 0,
     'spVal': 0.46568},
    'spInd': 0,
    'spVal': 0.467383},
   'right': {'left': 30.051931,
    'right': 17.171057000000001,
    'spInd': 0,
    'spVal': 0.463241},
   'spInd': 0,
   'spVal': 0.465561},
  'right': {'left': -34.044555000000003,
   'right': {'left': {'left': {'left': {'left': {'left': -4.1911744999999998,
        'right': {'left': {'left': {'left': {'left': 19.745224,
            'right': 15.224266,
            'spInd': 0,
            'spVal': 0.428582},
           'right': -21.594268,
           'spInd': 0,
           'spVal': 0.426711},
          'right': 44.161493,
          'spInd': 0,
          'spVal': 0.418943},
         'right': {'left': -26.419288999999999,
          'right': 0.63593000000000011,
          'spInd': 0,
          'spVal': 0.403228},
         'spInd': 0,
         'spVal': 0.412516},
        'spInd': 0,
        'spVal': 0.437652},
       'right': 23.197474,
       'spInd': 0,
       'spVal': 0.388789},
      'right': {'left': {'left': {'left': -29.007783,
         'right': {'left': {'left': 13.583555,
           'right': 5.2411960000000004,
           'spInd': 0,
           'spVal': 0.377383},
          'right': -8.2282969999999995,
          'spInd': 0,
          'spVal': 0.373501},
         'spInd': 0,
         'spVal': 0.378965},
        'right': {'left': -32.124495000000003,
         'right': {'left': -9.9938275000000001,
          'right': -26.851234812500003,
          'spInd': 0,
          'spVal': 0.350725},
         'spInd': 0,
         'spVal': 0.35679},
        'spInd': 0,
        'spVal': 0.370042},
       'right': {'left': 22.286959625000001,
        'right': {'left': {'left': -20.397333499999998,
          'right': -49.939515999999998,
          'spInd': 0,
          'spVal': 0.310956},
         'right': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': {'left': 8.8147249999999993,
                   'right': {'left': -18.051317999999998,
                    'right': {'left': -1.7983769999999999,
                     'right': {'left': -14.988279,
                      'right': -14.391613,
                      'spInd': 0,
                      'spVal': 0.290749},
                     'spInd': 0,
                     'spVal': 0.295993},
                    'spInd': 0,
                    'spVal': 0.297107},
                   'spInd': 0,
                   'spVal': 0.300318},
                  'right': {'left': 35.623745999999997,
                   'right': {'left': -9.4575560000000003,
                    'right': {'left': 5.2805790000000004,
                     'right': 2.5579230000000002,
                     'spInd': 0,
                     'spVal': 0.264639},
                    'spInd': 0,
                    'spVal': 0.264926},
                   'spInd': 0,
                   'spVal': 0.273863},
                  'spInd': 0,
                  'spVal': 0.284794},
                 'right': {'left': {'left': -9.601409499999999,
                   'right': -30.812912000000001,
                   'spInd': 0,
                   'spVal': 0.228751},
                  'right': -2.266273,
                  'spInd': 0,
                  'spVal': 0.228628},
                 'spInd': 0,
                 'spVal': 0.25807},
                'right': 6.0992389999999999,
                'spInd': 0,
                'spVal': 0.228473},
               'right': {'left': -16.427370249999999,
                'right': -2.6781804999999999,
                'spInd': 0,
                'spVal': 0.202161},
               'spInd': 0,
               'spVal': 0.211633},
              'right': 9.5773855000000001,
              'spInd': 0,
              'spVal': 0.193282},
             'right': {'left': {'left': {'left': -14.740059,
                'right': -6.5125060000000001,
                'spInd': 0,
                'spVal': 0.166431},
               'right': -27.405211000000001,
               'spInd': 0,
               'spVal': 0.164134},
              'right': 0.225886,
              'spInd': 0,
              'spVal': 0.156273},
             'spInd': 0,
             'spVal': 0.166765},
            'right': {'left': 7.5573490000000003,
             'right': 7.3367839999999998,
             'spInd': 0,
             'spVal': 0.13988},
            'spInd': 0,
            'spVal': 0.156067},
           'right': -29.087463,
           'spInd': 0,
           'spVal': 0.138619},
          'right': 22.478290999999999,
          'spInd': 0,
          'spVal': 0.131833},
         'spInd': 0,
         'spVal': 0.309133},
        'spInd': 0,
        'spVal': 0.324274},
       'spInd': 0,
       'spVal': 0.335182},
      'spInd': 0,
      'spVal': 0.382037},
     'right': -39.524461000000002,
     'spInd': 0,
     'spVal': 0.130626},
    'right': {'left': 22.891674999999999,
     'right': {'left': {'left': 6.1965159999999999,
       'right': {'left': -16.106164,
        'right': {'left': -1.2931950000000001,
         'right': -10.137104000000001,
         'spInd': 0,
         'spVal': 0.085873},
        'spInd': 0,
        'spVal': 0.10796},
       'spInd': 0,
       'spVal': 0.108801},
      'right': {'left': 37.820658999999999,
       'right': {'left': -24.132225999999999,
        'right': {'left': 15.824970500000001,
         'right': {'left': -15.160836,
          'right': {'left': {'left': {'left': 6.6955669999999996,
             'right': -3.131497,
             'spInd': 0,
             'spVal': 0.055862},
            'right': -13.731698,
            'spInd': 0,
            'spVal': 0.053764},
           'right': 4.0916259999999998,
           'spInd': 0,
           'spVal': 0.044737},
          'spInd': 0,
          'spVal': 0.061219},
         'spInd': 0,
         'spVal': 0.068373},
        'spInd': 0,
        'spVal': 0.080061},
       'spInd': 0,
       'spVal': 0.084661},
      'spInd': 0,
      'spVal': 0.085111},
     'spInd': 0,
     'spVal': 0.124723},
    'spInd': 0,
    'spVal': 0.126833},
   'spInd': 0,
   'spVal': 0.455761},
  'spInd': 0,
  'spVal': 0.457563},
 'spInd': 0,
 'spVal': 0.499171}
模型树

用树来对数据建模,除了把叶节点简单地设定为常数值之外,还有一种方法是把叶节点设定为分段线性函数,这里所谓的分段线性(piecewise linear)是指模型由多个线性片段组成

程序清单9-4 模型树的叶节点生成函数
#在创建树函数中,叶节点的模式函数和误差计算函数参数是通过函数引用给出的,模式树的构建过程与回归树的构建过程是一致的
#只是模式树的节点不是数值,而是回归函数(叶节点模式),误差计算函数也和回归树中误差计算不一样,所以模式树构建只需要传入这两个函数的应用就可实现树的构建
#对数据集进行标准线性回归,返回系数向量X矩阵和Y矩阵
def linearSlove(dataSet):
    #dataSet中包含了目标值列,标准线性回归中需要加入一列1,是为了将截距b的值加入ws向量中一起求出
    #获取数据集的行列大小
    m,n = np.shape(dataSet)
    #创建m*n大小的全1矩阵,这里相当于覆盖x矩阵中的最后一列
    X = np.mat(np.ones((m,n)))
    Y = np.mat(np.ones((m,1)))
    X[:,1:n] = dataSet[:,0:n-1]
    Y = dataSet[:,-1]
    xtx = X.T*X
    if np.linalg.det(xtx) == 0.0:
        raise NameError("矩阵不可逆")
    ws = xtx.I * (X.T*Y)
    return ws,X,Y

#叶节点模式函数
def modelLeaf(dataSet):
    ws,X,Y = linearSlove(dataSet)
    #叶节点模式函数返回回归系数向量
    return ws

#计算误差函数
def modelErr(dataSet):
    ws, X, Y = linearSlove(dataSet)
    #预测
    yHat = X * ws
    return sum(np.power(Y - yHat,2))
myDat2 = loadDataSet('exp2.txt')
myMat2 = mat(myDat2)
myTree = createTree(myMat2, leafType=modelLeaf, errType=modelErr, ops=(1,10))
print(myTree)

{'spInd': 0, 'spVal': 0.285477, 'left': matrix([[  1.69855694e-03],
        [  1.19647739e+01]]), 'right': matrix([[ 3.46877936],
        [ 1.18521743]])}
程序清单9-5 用树回归进行预测的代码
# 回归树测试案例
# 为了和 modelTreeEval() 保持一致,保留两个输入参数
# 模型效果计较
# 线性叶子节点 预测计算函数 直接返回 树叶子节点 值
def regTreeEval(model, inDat):
    return float(model)
# 模型树测试案例
# 对输入数据进行格式化处理,在原数据矩阵上增加第1列,元素的值都是1,
# 也就是增加偏移值,和我们之前的简单线性回归是一个套路,增加一个偏移量
def modelTreeEval(model, inDat):
    n = shape(inDat)[1]
    X = mat(ones((1,n+1))) # 增加一列
    X[:,1:n+1]=inDat
    # print X # [[  1.  12.]]
    # print '==============='
    # print model
    '''[[ -2.87684083]
        [ 10.20804482]]
    '''
    return float(X*model) # 返回 值乘以 线性回归系数

# 计算预测的结果
# 在给定树结构的情况下,对于单个数据点,该函数会给出一个预测值。
# modelEval是对叶节点进行预测的函数引用,指定树的类型,以便在叶节点上调用合适的模型。
# 此函数自顶向下遍历整棵树,直到命中叶节点为止,一旦到达叶节点,它就会在输入数据上
# 调用modelEval()函数,该函数的默认值为regTreeEval()
def treeForeCast(tree, inData, modelEval=regTreeEval):
    """
        Desc:
            对特定模型的树进行预测,可以是 回归树 也可以是 模型树
        Args:
            tree -- 已经训练好的树的模型
            inData -- 输入的测试数据
            modelEval -- 预测的树的模型类型,可选值为 regTreeEval(回归树) 或 modelTreeEval(模型树),默认为回归树
        Returns:
            返回预测值
        """
    if not isTree(tree):
        return modelEval(tree, inData) # 返回 叶子节点 预测值
    # print inData[tree['spInd']]
    # print '==============='
    # print tree['spVal']
    # print inData[tree['spInd']] > tree['spVal']
    '''
     [[ 12.]]
     ===============
    10.0
     [[ True]]
     '''
    if inData[tree['spInd']] > tree['spVal']: # 左树 [[ True]]或者 [[False]]
        if isTree(tree['left']):
            # 还是树 则递归调用
            return treeForeCast(tree['left'], inData, modelEval)
        else:
            # 计算叶子节点的值 并返回
            return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']):  # 右树
            return treeForeCast(tree['right'], inData, modelEval)
        else:
            return modelEval(tree['right'], inData)  # 计算叶子节点的值 并返回

# 得到预测值
def createForeCast(tree, testData, modelEval=regTreeEval):
    m=len(testData)
    # print mat(testData[0]) # [[ 12.]]
    # print mat(testData[1]) # [[ 19.]]
    yHat = mat(zeros((m,1)))
    for i in range(m):
        yHat[i,0] = treeForeCast(tree, mat(testData[i]), modelEval)
        # if i == 1:
        #     break
    return yHat

trainSet = np.mat(loadDataSet('bikeSpeedVsIq_train.txt'))
testSet = np.mat(loadDataSet('bikeSpeedVsIq_test.txt'))
#回归树预测
myTree = createTree(trainSet,ops=(1,20))
#数据集中最后一列是目标值
yHat = createForeCast(myTree,testSet[:,0])
corr = np.corrcoef(yHat,testSet[:,1],rowvar=0)[0,1]
print("regTreeCorr:",corr) #回归树
regTreeCorr: 0.964085231822
myTree = createTree(trainSet,modelLeaf,modelErr,ops=(1, 20))
    # 数据集中最后一列是目标值
yHat = createForeCast(myTree, testSet[:, 0],modelTreeEval)
corr = np.corrcoef(yHat, testSet[:, 1], rowvar=0)[0, 1]
print("modelTreeCorr:", corr) #模型树

modelTreeCorr: 0.976041219138
#进行标准线性回归
ws,X,Y = linearSlove(trainSet)
print("ws:",ws)
yHat = np.zeros((np.shape(testSet)[0],1))
for i in range(np.shape(testSet)[0]):
    yHat[i] = testSet[i,0] * ws[1,0] + ws[0,0]
    corr = np.corrcoef(yHat,testSet[:,1],rowvar=0)[0,1]
print("linerCorr:",corr)

ws: [[ 37.58916794]
 [  6.18978355]]
linerCorr: 0.943468423567
'''
from tkinter import *
root = Tk()
myLabel = Label(root, text ="Hello World")
myLabel.grid()
root.mainloop()
'''
'\nfrom tkinter import *\nroot = Tk()\nmyLabel = Label(root, text ="Hello World")\nmyLabel.grid()\nroot.mainloop()\n'
程序清单9-7 Matplotlib 和 Tkinter的代码集成
import matplotlib

#将matplotlib后端设置为TkAgg
matplotlib.use('TkAgg')
#将tkinter和matplotlib关联起来
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure


def reDraw(tolS, tolN):
    reDraw.f.clf()  # 清空画布
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        if tolN < 2: tolN = 2
        myTree = createTree(reDraw.rawDat,modelLeaf,modelErr, (tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat,modelTreeEval)
    else:
        myTree = createTree(reDraw.rawDat, ops=(tolS, tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    #绘制散点图
    reDraw.a.scatter(reDraw.rawDat[:, 0].tolist(), reDraw.rawDat[:, 1].tolist(), s=5)  # use scatter for data set
    #绘制连续图
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)  # use plot for yHat
    reDraw.canvas.show()

def getInputs():
    try:
        tolN = int(tolNentry.get())
    except:
        tolN = 10
        print("tolN请输入整数值")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get())
    except:
        tolS = 1.0
        print("tolS请输入浮点值")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS

def drawNewTree():
    tolN, tolS = getInputs()  # 从输入文本框中获取参数
    reDraw(tolS, tolN)  #绘制图
程序清单9-6 用于构建树管理器界面的Tkinter小部件

#创建跟部件
root = Tk()

#Label(root,text="Plot Place Holder").grid(row = 0 ,columnspan = 3)
reDraw.f = Figure(figsize=(5, 4), dpi=100)  # 创建画布
reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
reDraw.canvas.show()
reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)


#第一行标签
Label(root,text="tolN").grid(row = 1 ,column = 0)
tolNentry = Entry(root)
tolNentry.grid(row = 1 ,column = 1)
tolNentry.insert(0,'10')
#第二行标签
Label(root,text="tolS").grid(row = 2 ,column = 0)
tolSentry = Entry(root)
tolSentry.grid(row = 2,column = 1)
tolSentry.insert(0,'1.0')

#按钮
Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
#复选框变量
chkBtnVar = IntVar()
chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar)
chkBtn.grid(row=3, column=0, columnspan=2)
#获取变量
reDraw.rawDat = mat(loadDataSet('sine.txt'))
reDraw.testDat = arange(min(reDraw.rawDat[:, 0]), max(reDraw.rawDat[:, 0]), 0.01)
reDraw(1.0, 10)

root.mainloop()

在这里插入图片描述
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值