cart树的代码示例参考机器学习实战

本文详细介绍了如何通过代码实现CART(分类与回归树)算法,结合实例展示了CART在机器学习中的应用,帮助读者理解决策树的工作原理。
摘要由CSDN通过智能技术生成
from numpy import *
import numpy as np
import pickle
import matplotlib.pyplot as plt
import sys
from matplotlib.font_manager import FontProperties  # 设置字体属性

def loadDataSet(fileName):
    '''
    读取一个一tab键为分隔符的文件,然后将每行的内容保存成一组浮点数
    '''
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float,curLine))
        dataMat.append(fltLine)
    return dataMat


def binSplitDataSet(dataSet, feature, value):
    '''
    数据集切分函数
    '''
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
    return mat0,mat1

#--------------回归树所需子函数---------------#
def regLeaf(dataSet):
    '''负责生成叶节点'''
    #当chooseBestSplit()函数确定不再对数据进行切分时,将调用本函数来得到叶节点的模型。
    #在回归树中,该模型其实就是目标变量的均值。
    return np.mean(dataSet[:,-1])


def regErr(dataSet):
    '''
    误差估计函数,该函数在给定的数据上计算目标变量的平方误差,这里直接调用均方差函数
    '''
    return var(dataSet[:,-1]) * shape(dataSet)[0]#返回总方差


#--------------回归树子函数  END  --------------#


def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    '''
    用最佳方式切分数据集和生成相应的叶节点
    '''
    #ops为用户指定参数,用于控制函数的停止时机
    tolS = ops[0]; tolN = ops[1]
    #如果所有值相等则退出
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    #在所有可能的特征及其可能取值上遍历,找到最佳的切分方式
    #最佳切分也就是使得切分后能达到最低误差的切分
    for featIndex in range(n-1):
        for splitVal in set(dataSet[:, featIndex].T.A.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #如果误差减小不大则退出
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    #如果切分出的数据集很小则退出
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    #提前终止条件都不满足,返回切分特征和特征值
    return bestIndex,bestValue

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    '''
    树构建函数
    leafType:建立叶节点的函数
    errType:误差计算函数
    ops:包含树构建所需其他参数的元组
    '''
    #选择最优的划分特征
    #如果满足停止条件,将返回None和某类模型的值
    #若构建的是回归树,该模型是一个常数;如果是模型树,其模型是一个线性方程
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat == None: return val #
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    #将数据集分为两份,之后递归调用继续划分
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


def storeTree(inputTree, filename):
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

if __name__ == '__main__':
    myDat = loadDataSet('ex00.txt')
    x=[x[0] for x in myDat]
    y=[y[1] for y in myDat]
    myMat = mat(myDat)
    retTree = createTree(myMat)
    print(retTree)
    storeTree(retTree, 'retTree.txt')

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值