机器学习实战 -ch09.树回归(CART算法)

原创 2016年08月29日 19:52:53

一. CART vs ID3

这里写图片描述

二. 算法代码及注释

没有考虑后面的“树回归和标准回归的对比”,对于剪枝原理也还有待深入的理解

# -*- coding:utf-8 -*-  
from numpy import *

#读取数据到矩阵

def loadDataSet(filename):
    dataMat = []

    fr = open(filename)

    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = map(float,curLine)  # 将每个数据转成浮点型
        dataMat.append(fltLine)
    return dataMat
#-----------------线性回归算法--------------------------------------#
# 线性回归算法实现,返回构建的模型
def linearSolve(dataSet):
    m,n = shape(dataSet)

    # 创建X与Y矩阵,并将X,Y中的数据格式化
    X = mat(ones((m,n))); Y = mat(ones((m,1)))
    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]
    xTx = X.T*X

    # 判断是否为奇异矩阵
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws,X,Y
 #-------------------------模型树---------------------------#

def modeLeaf(dataSet):
    ws,X,Y = linearSolve(dataSet)
    return ws

def modeErr(dataSet):
    ws,X,Y = linearSolve(dataSet)
    yHat = X *ws
    return sum(power(Y -yHat,2))
#----------------------CART算法----------------------------------#
#“二元切分法”切分数据集:在给定特征和特征值的情况下,通过数组过滤的方式将输入的数据集合切分得到两个子集并返回
def binSplitDataSet(dataSet,feature,value):  #输入:数据集合,待切分特征,该特征的某个值
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]  #注意此处书上源码有误
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
    return mat0,mat1

#生成叶节点的函数
def regLeaf(dataSet):
    return mean(dataSet[:,-1])

#误差计算函数
def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]  #总方差 = 均方差 × 样本数目


#切分函数:根据“总方差”确定最佳二元切分方式(是回归书构建的核心函数!)
def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
    tolS = ops[0]  #允许的误差下降值
    tolN = ops[1]  #切分的最少样本数量
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:  #当剩余特征值的数目为1的时候,切分结束,直接返回(tolist()是将当前对象转换为python的list对象返回)
        return None, leafType(dataSet)
    m,n = shape(dataSet)   #初始化循环变量
    S = errType(dataSet)
    bestS = inf 
    bestIndex = 0
    bestValue = 0
    #需找最佳二元划分方式: 遍历数据集合中所有特征的所有样本(注意描述顺序),需找到能够使得切分后数据集合效果提升(误差下降最多)对应的特征和特征值返回
    for  featIndex in range(n-1):
        for splitVal  in set(dataSet[:,featIndex].flat): #处书上源码有误(flat()是获得将数组展开为一维的迭代器),这里与dataSet[:,].T.tolist()[0]等效)
            mat0,mat1 = binSplitDataSet(dataSet,featIndex,splitVal)
            if(shape(mat0)[0] < tolN or shape(mat1)[0] < tolN):   #当样本数量小于预设值tolN时,直接进行下一轮判断
                continue
            newS =errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:                         #当误差减小效果不明显(误差下降值未达到于设置值tolS)时,直接返回None和叶节点
        return None, leafType(dataSet)
    mat0,mat1 = binSplitDataSet(dataSet,bestIndex,bestValue)
    if(shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #当样本数量小于预设值tolN时
        return None, leafType(dataSet)
    return bestIndex, bestValue


#构建树:根据函数chooseBestSplit()切分数据,递归的构建二叉树
#输入:数据集合,生成叶节点的函数,误差计算函数,构建树所需要的其他参数元组 
def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):   
    #尝试将数据集合分成两部分
    #切分函数按chooseBestSplit()函数进行
    feat,val = chooseBestSplit(dataSet,leafType,errType,ops)
    if feat == None:   # 如果满足停止条件,chooseBestSplit()函数返回None和某类模型的值(回归树:常数;模型树:线性方程)
        return val
    #将数据集合分成左子树、右子树两部分
    retTree= {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet,rSet = binSplitDataSet(dataSet,feat,val)
    #继续递归左子树和右子树
    retTree['left'] = createTree(lSet,leafType,errType,ops)
    retTree['right'] = createTree(rSet,leafType,errType,ops)
    return retTree


#树剪枝:
#伪代码
#基于已有的树切分测试数据:
#     如果存在任一子集是一棵树,则在该子集地柜剪枝过程
#     计算将当前两个叶节点合并后的误差
#     计算不合并的误差
#     如果合并的误差会降低的话,就将叶节点合并

# 测试输入变量是否位一棵树(也就相当于测试当前处理的节点是否为叶节点)
def isTree(obj):
    return (type(obj).__name__ == 'dict')

# 递归函数,从上往下遍历树直到叶节点为止,然后对树做塌陷处理(即返回树的平均值)
def getMean(tree):
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right'])/2.0

#剪枝主函数: 
#输入: 待剪枝的树,剪枝所需的测试数据
def prune(tree,testData):
    #测试集合为空,则直接对树进行塌陷处理
    if shape(testData)[0] == 0:  
        return getMean(tree)
    #测试集合非空,则递归地调用prune()函数对测试数据进行切分
    if (isTree(tree['right'])) or (isTree(tree['left'])):
        lSet ,rSet = binSplitDataSet(testData,tree['spInd'],tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'],lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'],rSet)
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet ,rSet = binSplitDataSet(testData,tree['spInd'],tree['spVal'])
        errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) + sum(power(rSet[:,-1] - tree['right'],2)) #power()用于求幂
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:,-1] - treeMean,2))
        if errorMerge < errorNoMerge:
            print "merging"
            return  treeMean
        else:
            return tree
    else:
        return tree
版权声明:本文为博主原创文章,未经博主允许不得转载。

相关文章推荐

周志华《机器学习》课后习题解答系列(五):Ch4.4 - 编程实现CART算法与剪枝操作

基于训练集构建的完全决策树易陷入过拟合,为提升模型泛化能力,通常需要对树进行剪枝。此处基于基尼系数构建出决策树(CART算法),然后编程实现预剪枝和后剪枝操作,最后分析比较了他们的作用。...

《机器学习实战》基于信息论的三种决策树算法(ID3,C4.5,CART)

决策树是通过一系列规则对数据进行分类的过程,他提供一种在什么条件下会得到什么值的类似规则方法,决策树分为分类树和回归树,分类树对离散变量最决策树,回归树对连续变量做决策树如果不考虑效率等,那么样本所有...

机器学习实战-CART分类回归树

树回归           虽然线性回归有强大的功能,但是在遇到数据具有很多特征时且特征之间具有复杂的关系时,构建全局的模型就显得比较难,而且也比较笨重,而且实际中处理的数据一般都是非线性的,不可...

机器学习算法之CART(分类回归树)概要

分类回归树  classification and regression tree(C&RT)  racoon 优点 (1)可自动忽略对目标变量没有贡献的属性变量,也为判断属性变量的重要性,减少变量...

机器学习算法之CART(分类和回归树)

CART算法介绍: 分类和回归树(CART)是应用广泛的决策树学习方法。CART同样由特征选择,树的生成和减枝组成,既可以用于分类也可以用于回归。CART的生成就是递归的构建二叉决策树的过程。对回归...

简单易学的机器学习算法——分类回归树CART

一、树回归的概念 二、

机器学习算法-分类回归树CART

本文转载http://www.cnblogs.com/zhangchaoyang/articles/270992

机器学习经典算法详解及Python实现--CART分类决策树、回归树和模型树

摘要: Classification And Regression Tree(CART)是一种很重要的机器学习算法,既可以用于创建分类树(Classification Tree),也可以用于创建...

机器学习实战笔记_09_树回归_代码错误修正

本人用的是python 2.7,但是敲击书上的源代码,总是运行错误,发现代码有两处错误,可以把我的代码和书上的代码对照,错误地方已经标出regTrees.py from numpy import * ...

机器学习实战——ch8.2 回归之预测乐高玩具价格

这部分由于书上提供的Google的购物API已经关闭,所以只能在实验楼上完成了这个实验(这一次,我只是代码的搬运工) 完整的代码及注释:#-*- coding: utf-8 -*- from num...
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:机器学习实战 -ch09.树回归(CART算法)
举报原因:
原因补充:

(最多只允许输入30个字)