Machine Learning Notes: Decision Trees (the CART Method)

This post gives a detailed introduction to the decision-tree algorithm, focusing on CART (Classification and Regression Trees). Through a worked implementation, it shows how CART builds classification and regression trees, and discusses feature-selection criteria and pruning strategies, helping readers master the process of building a decision-tree model.
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 24 11:14:58 2019

@author:wangtao_zuel

E-mail:wangtao_zuel@126.com

Decision trees: the CART method

"""

import numpy as np
import pandas as pd

def loadData(filepath,fileType):
    """
    Return the data as a matrix. If the sample labels are strings, adjust the reading logic accordingly.
    """
    if fileType == 'xlsx':
        data = pd.read_excel(filepath)
    elif fileType == 'csv':
        data = pd.read_csv(filepath)
    else:
        data = pd.read_csv(filepath,sep='\t',header=None)
    data = np.mat(data)
    
    return data

def binSplitDataSet(dataSet,featInd,featVal):
    """
    Binary-split the samples on a feature (by index) and a feature value; rows with values <= featVal go into the left matrix (matL).
    """
    matL = dataSet[np.nonzero(dataSet[:,featInd] <= featVal)[0],:]
    matR = dataSet[np.nonzero(dataSet[:,featInd] > featVal)[0],:]
    
    return matL,matR
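
# A quick sanity check of the split on a made-up toy matrix (the demo
# names below are illustrative only): splitting feature 0 at 1.0 sends
# the first two rows left and the last two rows right.
demoSplit = np.mat([[0.5, 10.0],
                    [1.0, 11.0],
                    [1.5, 20.0],
                    [2.0, 21.0]])
matL_demo, matR_demo = binSplitDataSet(demoSplit, 0, 1.0)
print(matL_demo.shape[0], matR_demo.shape[0])  # 2 2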

def regLeaf(dataSet):
    """
    Create a leaf node.
    Returns the mean of the target values under the branch; suits regression.
    """
    return np.mean(dataSet[:,-1])

def maxLeaf(dataSet):
    """
    Create a leaf node.
    Returns the majority class under the branch.
    """
    results = uniqueCount(dataSet)
    
    return max(results,key=results.get)

def uniqueCount(dataMat):
    """
    Count the number of samples in each class.
    Note this expects matrix data; for other data types, adjust the loop
    over dataMat[:,-1].T.tolist()[0].
    """
    results = {}
    for sample in dataMat[:,-1].T.tolist()[0]:
        if sample not in results:
            results[sample] = 0
        results[sample] += 1
    
    return results

def regErr(dataSet):
    """
    Error calculation.
    Uses the total squared error (sample count times variance); suits regression.
    """
    var = np.var(dataSet[:,-1])
    m = dataSet.shape[0]
    err = m*var
    
    return err

def entErr(dataSet):
    """
    香农熵计算误差(混乱程度)
    """
    results = uniqueCount(dataSet)
    sampleNum = dataSet.shape[0]
    shannonEnt = 0.0
    for key in results:
        prob = float(results[key])/sampleNum
        shannonEnt -= prob*np.log2(prob)
    
    return shannonEnt

def giniErr(dataSet):
    """
    基尼不纯度计算误差(混乱程度)
    """
    sampleNum = dataSet.shape[0]
    results = uniqueCount(dataSet)
    imp = 0.0
    for k1 in results:
        p1 = float(results[k1])/sampleNum
        for k2 in results:
            if k1 == k2:
                continue
            p2 = float(results[k2])/sampleNum
            imp += p1*p2
    
    return imp
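
# A quick worked check of the two impurity measures (toy label counts,
# made up for illustration): with 3 samples of class 1 and 1 of class 0,
#   entropy = -(0.75*log2(0.75) + 0.25*log2(0.25)) ≈ 0.811 bits
#   Gini    = 1 - (0.75**2 + 0.25**2)              = 0.375
# so the count-weighted values returned above are 4x these numbers:
# roughly 3.245 and exactly 1.5.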

def chooseBestSplit(dataSet,leafType,errType,ops):
    """
    筛选最优分类特征、特征值
    """
    # 预剪枝参数,当优化(误差减小)过小或者分类太细(分支下样本数量太少),选择忽略
    minErr = ops[0]
    minNum = ops[1]
    # 若某分支下样本均为同一类,则返回建立叶节点
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None,leafType(dataSet)
    m,n = dataSet.shape
    # 不分类误差
    basicErr = errType(dataSet)
    bestErr = np.inf
    bestInd = 0
    bestVal = 0
    # 获取最小误差
    for featInd in range(n-1):
        for featVal in set(dataSet[:,featInd].T.tolist()[0]):
            matL,matR = binSplitDataSet(dataSet,featInd,featVal)
            # 判断分支下样本数目是否过小,预剪枝的一部分
            if (matL.shape[0] < minNum) or (matR.shape[0] < minNum):
                continue
            newErr = errType(matL) + errType(matR)
            if newErr < basicErr:
                bestInd = featInd
                bestVal = featVal
                bestErr = newErr
    # 若优化太小,分类和不分类相差不大,则忽略优化,其实这部分也是预剪枝的一部分,
    if (basicErr - bestErr) < minErr:
        return None,leafType(dataSet)
    # 二次判断,和前面的部分并未冲突,这部分用于处理没有最优分类特征、特征值的情况
    matL,matR = binSplitDataSet(dataSet,bestInd,bestVal)
    if (matL.shape[0] < minNum) or (matR.shape[0] < minNum):
        return None,leafType(dataSet)
    
    return bestInd,bestVal

def creatTree(dataSet,leafType,errType,ops):
    """
    递归创建树
    """
    # 选择最优的分类特征、特征值
    spInd,spVal = chooseBestSplit(dataSet,leafType,errType,ops)
    # 创建叶节点情况
    if spInd == None:
        return spVal
    # 创建子树
    tree = {}
    tree['spInd'] = spInd
    tree['spVal'] = spVal
    # 递归得到子分支树
    matL,matR = binSplitDataSet(dataSet,spInd,spVal)
    tree['left'] = creatTree(matL,leafType,errType,ops)
    tree['right'] = creatTree(matR,leafType,errType,ops)
    
    return tree
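
# A minimal end-to-end sketch on made-up data (the demo names below are
# illustrative only): x sits in column 0, y in column 1, and y jumps from
# about 0.2 to about 10 at x = 0.5, so a single split should suffice.
demoData = np.mat([[0.1, 0.2], [0.2, 0.1], [0.3, 0.3],
                   [0.6, 9.9], [0.7, 10.1], [0.8, 10.0]])
demoTree = creatTree(demoData, regLeaf, regErr, ops=(0.1, 1))
print(demoTree)  # roughly {'spInd': 0, 'spVal': 0.3, 'left': 0.2, 'right': 10.0}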

"""
# Post-pruning
"""

def isTree(obj):
    """
    Return True if the branch object is a subtree (a dict).
    """
    return (type(obj).__name__=='dict')

def getMean(tree):
    """
    塌陷处理,返回左右分支的平均值作为上一节点的值
    """
    if isTree(tree['left']):
        return getMean(tree['left'])
    if isTree(tree['right']):
        return getMean(tree['right'])
    
    return (tree['left']+tree['right'])/2

def regPrune(tree,testData):
    """
    递归后剪枝,需要一定数量的测试集,最好数量和样本集相同
    注意这种剪枝方法适合用于结果是连续型数据(按平均值塌陷不太适合分类,因为类别是固定的)
    """
    # 若无测试集,则做塌陷处理
    if testData.shape[0] == 0:
        return getMean(tree)
    # 判断节点下是否为子树,若为子树则进一步细分处理,直至节点下均为叶节点
    if (isTree(tree['left'])) or (isTree(tree['right'])):
        lSet,rSet = binSplitDataSet(testData,tree['spInd'],tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = regPrune(tree['left'],lSet)
    if isTree(tree['right']):
        tree['right'] = regPrune(tree['right'],rSet)
    # 当节点下都为叶节点时,判断是否进行合并处理
    if (not isTree(tree['left'])) and (not isTree(tree['right'])):
        lSet,rSet = binSplitDataSet(testData,tree['spInd'],tree['spVal'])
        # 计算未合并时的误差(混乱程度)
        notMergeErr = sum(np.power(lSet[:,-1]-tree['left'],2)) + sum(np.power(rSet[:,-1]-tree['right'],2))
        treeMerge = (tree['left']+tree['right'])/2
        mergeErr = sum(np.power(testData[:,-1]-treeMerge,2))
        if mergeErr < notMergeErr:
            print("Merging!")
            return treeMerge
        else:
            return tree
    # 若节点下不全为叶节点,则不执行合并剪枝操作
    else:
        return tree
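
# A small post-pruning sketch reusing demoTree from above, with a made-up
# test set that follows the same pattern; no merge should be triggered,
# so the tree comes back unchanged.
demoTest = np.mat([[0.15, 0.2], [0.25, 0.2], [0.65, 9.8], [0.75, 10.2]])
print(regPrune(demoTree, demoTest))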

def outJudge(dataSet,tree):
    """
    遍历判断样本外数据类型
    """
    outputData = pd.DataFrame(dataSet)
    classResults = []
    for ii in range(dataSet.shape[0]):
        result = judgeType(dataSet[ii,:].A[0],tree)
        classResults.append(result)
    outputData['classResults'] = classResults
    outputData.to_excel('./data/machine_learning/mytree.xlsx',index=False,encoding='utf-8-sig')
    print("样本外数据分类(判断)完成!")

def judgeType(data,tree):
    """
    递归判断分类
    """
    spInd = tree['spInd']
    spVal = tree['spVal']
    if data[spInd] <= spVal:
        # 若节点下为子树则递归,否则返回叶节点的值
        if isTree(tree['left']):
            return judgeType(data,tree['left'])
        return tree['left']
    else:
        if isTree(tree['right']):
            return judgeType(data,tree['right'])
        return tree['right']
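
# Routing a single new sample through the toy tree above: x = 0.7 falls
# on the right branch, so the prediction is the right leaf value (~10.0).
print(judgeType([0.7], demoTree))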

def treeCart(trainDataPath,outDataPath='',testDataPath='',leafType=regLeaf,errType=regErr,ops=(1,4),prune=False,fileType='txt'):
    """
    主函数,参数含义:
    trainDataPath:训练集数据路径
    outDataPath:样本外数据路径
    testDataPath:测试集数据路径,当需要后剪枝操作时需输入
    leafType:创建叶节点方式
    errType:误差(混乱程度)计算方式
    ops:预剪枝参数,第一个元素表示能忽略的最小误差,第二个元素表示当某分支下样本数小于该元素时,不考虑建立该分支
    prune:是否进行后剪枝操作
    fileType:训练集、测试集数据类型(xlsx、txt、csv),txt文件需以制表符\t为分割
    """
    dataMat = loadData(trainDataPath,fileType)
    try:
        myTree = creatTree(dataMat,leafType,errType,ops)
        if prune:
            testData = loadData(testDataPath,fileType)
            myTree = regPrune(myTree,testData)
            print('决策树构建完成!')
            print(myTree)
        else:
            print('决策树构建完成!')
            print(myTree)
        # 预测(分类操作)
        if outDataPath != '':
            outData = loadData(outDataPath,fileType)
            outJudge(outData,myTree)
    except:
        print("检查是否正确输入参数!")
        print('请在函数treeCart中输入叶节点创建方式参数:\n\t1、按平均值创建:leafType=regLeaf\n\t2、按最多样本创建:leafType=maxLeaf')
        print('请在treeCart中输入误差计算方式参数:\n\t1、香农熵:errType=entErr\n\t2、基尼不纯度:errType=giniErr\n\t3、平方误差:regErr')
        print('请在treeCart中输入预剪枝参数ops:\n\t其中第一个元素表示能忽略的最小误差,第二个元素表示当某分支下样本数小于该元素时,不考虑建立该分支')
        print('示例:treeCart(trainDataPath,leafType=regLeaf,errType=regErr,ops=(1,4))')
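
# Example call (the paths below are placeholders; point them at your own
# tab-separated data files before running):
# treeCart('./data/train.txt', outDataPath='./data/new_samples.txt',
#          testDataPath='./data/test.txt', leafType=regLeaf,
#          errType=regErr, ops=(1, 4), prune=True, fileType='txt')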