机器学习实战Ch09-树回归(CART)

1、CART生成

回归树采用平方误差最小化准则。
分类树采用基尼指数最小化准则。

from numpy import *
from matplotlib import pyplot as plt


def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float, curLine))  # 将数据映射成浮点型,返回的是map的地址        dataMat.append(fltLine)
        dataMat.append(list(fltLine))
    return dataMat


def plotData(dataMat):
    x = dataMat[:, 0].tolist()
    y = dataMat[:, 1].tolist()
    plt.scatter(x, y)
    plt.title('DataSet')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()


def plotData2(dataMat):
    x = dataMat[:, 1].tolist()
    y = dataMat[:, 2].tolist()
    plt.scatter(x, y)
    plt.title('DataSet')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()


def binSplitDataSet(dataSet, feature, value):
    """
    根据某个特征及值对数据进行切分
    :param dataSet: 数据集
    :param feature: 特征
    :param value: 特征值
    :return: 切分后的数据
    """
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1


def regLeaf(dataSet):
    """
    生成叶节点
    :param dataSet: 数据集
    :return: 使用均值作为叶节点的值
    """
    return mean(dataSet[:, -1])


# 误差估计函数
def regErr(dataSet):
    """
    误差估计函数
    :param dataSet: 数据集
    :return: 总方差
    """
    return var(dataSet[:, -1]) * shape(dataSet)[0]


def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """
    找到数据的最佳二元切分方式函数
    :param dataSet:数据组
    :param leafType:生成叶节点
    :param errType:误差估计函数
    :param ops:用户自定义参数构成的元组
    :return:
        bestIndex:最佳切分特征的下标
        bestValue:最优切分值
    """
    """
    伪代码:
        首先判断是否所有值相等:
         否:计算误差值,初始化最佳误差,最优切分特征下标,最优切分值
            遍历每个特征:
                遍历特征的每个值:
                    将数据切分为两部分
                    判断是否小于最少切分样本值:
                    是:跳出循环
                    否:计算切分后误差
                        如果切分后误差小于当前最佳误差,则更新最优切分特征下标,最优切分值,最优误差
            如果误差下降值小于最小误差值,则返回最佳切分特征值和最优切分值
    """
    tolS = ops[0];
    tolN = ops[1]  # tolS是允许的最小误差下降值,tolN是切分的最少样本
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # 如果所有值相等则退出,根据set特性
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    S = errType(dataSet) #计算误差值
    bestS = inf;
    bestIndex = 0;
    bestValue = 0
    for featIndex in range(n - 1):  #遍历每个特征值
        for splitVal in set((dataSet[:, featIndex].T.tolist())[0]):  # 遍历特征的每个值
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)  #将数据集切分为两部分
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue  #如果小于切分的最少样本数则直接跳到下一次循环
            newS = errType(mat0) + errType(mat1)  # 计算新的误差
            if newS < bestS:  # 如果小于最佳误差则更新
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # 如果误差下降值小于允许的最小误差下降值,则不需要切分
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    # 切分数据集
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # 如果切分后的数据集个数小于tolN,则不进行切分
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue


def createTree(dataSet, leafType=regLeaf, errType=regErr,ops=(1, 4)):
    """
    建树
    :param dataSet:数据集
    :param leafType:生成叶子节点函数
    :param errType:误差函数
    :param ops:用户自定义参数,对树进行预剪枝
    :return:树
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat == None: return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

# 测试
if __name__ == '__main__':

    ex00Data = loadDataSet('ex00.txt')
    ex00Mat = mat(ex00Data)
    plotData(ex00Mat)
    myTree1 = createTree(ex00Mat)
    print(myTree1)

    ex0Data = loadDataSet('ex0.txt')
    ex0Mat = mat(ex0Data)
    plotData2(ex0Mat)
    myTree2 = createTree(ex0Mat)
    print(myTree2)
    print(createTree(ex0Mat, ops=(0,1)))

2、树剪枝

(1)预剪枝

if __name__ == '__main__':
    ex2Data = loadDataSet('ex2.txt')
    ex2Mat = mat(ex2Data)
    plotData(ex2Mat)
    ex2Tree = createTree(ex2Mat)
    print(ex2Tree)   #非常多的叶子节点
    print(createTree(ex2Mat,ops=(10000,4))) #只有两个叶子节点,tolS对误差的数据集非常敏感

(2)后剪枝

def isTree(obj):
    """
    函数说明:判断是否是一棵树
    :param obj:
    :return:
    """
    return (type(obj).__name__=='dict')


def getMean(tree):
    """
    函数说明:递归函数,对树进行塌陷处理(返回树平均值)
    :param tree::return: 树的左右子树的均值
    """
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    return (tree['left'] + tree['right']) / 2.0


def prnue(tree, testData):
    """
    函数说明:后剪枝函数,使用测试数据对树进行剪枝
    :param tree:  待剪枝的树
    :param testData:   测试数据
    :return: 剪枝后的树
    """
    """
        树的后剪枝:
        伪代码:
            基于已有的树切分测试数据:
                如果存在任一子集是一棵树,则在该子集递归剪枝过程
                计算将当前两个叶节点合并后的误差
                计算没有合并时的误差
                如果合并可以降低误差则进行合并
    """
    #如果测试数据集为空,进行塌陷处理
    if(shape(testData[0]) == 0): return getMean(tree)
    # if isTree(tree['left']) and isTree(tree['right']):
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prnue(tree['left'], lSet)  #左子集不是单个节点,则对左子树进行剪枝
    if isTree(tree['right']): tree['right'] = prnue(tree['right'], rSet) #右子集不是单个节点,则对右子树进行剪枝
    if not isTree(tree['left']) and not isTree(tree['right']):  #左右子集都是单个节点
        errNoMerge = sum(power((lSet[:,-1] - tree['left']),2)) + sum(power((rSet[:,-1] - tree['right']),2)) #未合并的误差
        treeMean = getMean(tree)
        errMerge = sum(power(testData[:,-1] - treeMean, 2))   #合并后误差
        print(errNoMerge, errMerge)
        if errMerge < errNoMerge:   #如果合并后误差小于合并前误差则合并
            print("merging")
            return treeMean
        else: return tree
    else: return tree


if __name__ == '__main__':
    ex2Data = loadDataSet('ex2.txt')
    ex2Mat = mat(ex2Data)
    plotData(ex2Mat)
    ex2Tree = createTree(ex2Mat)
    print(ex2Tree)

    ex2TestData = loadDataSet('ex2test.txt')
    prnueTree = prnue(ex2Tree, mat(ex2TestData))
    print(prnueTree)

3、进行预测

简单线性回归、回归树、模型树的对比
使用决定系数R^2值来判断预测效果,越接近于1越好。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值