决策树2 -- CART算法

声明:

       1,本篇为个人对《2012.李航.统计学习方法.pdf》的学习总结。不得用作商用,欢迎转载,但请注明出处(即:本帖地址)

       2,因为本人在学习初始时有非常多数学知识都已忘记。所以为了弄懂当中的内容查阅了非常多资料,所以里面应该会有引用其它帖子的小部分内容。假设原作者看到能够私信我。我会将您的帖子的地址付到以下。

       3,假设有内容错误或不准确欢迎大家指正。

       4,假设能帮到你,那真是太好了。

介绍

         CART是在给定输入变量X条件下,输出随机变量Y的条件概率分布的学习方法。

         CART如果决策树是二叉树,内部节点特征取值为“是”或“否”,左分支是取值为“是”的分支,右分支是取值为“否”的分支

这样决策树等价于递归的二分每一个特征(即使数据有多个取值,也把数据分成两部分)

         CART算法由下面两步组成:

                   1)决策树生成:基于训练数据集生成决策树。生成的决策树要尽量大。

                   2)决策树剪枝:用验证数据集对已生成的树进行剪枝并选择最优子树。这是用损失函数最小作为剪枝的标准。

 

CART生成

         决策树的生成即:递归构建二叉决策树的过程

                   对回归树:用平方误差最小化准则。进行特征选择,生成二叉树

                   对分类树:用基尼指数最小化准则,进行特征选择,生成二叉树

         PS:基尼指数(Gini)代表了某一集合的不确定性,Gini越大。样本集合的不确定性就越大。这点和熵相似。

         这里总结分类树。

分类树的生成 – 基尼指数

         首先,须要注意。在计算概率分布的基尼指数时须要考虑不同的情况:

         情况1:

                   如果有N个类,样本点属于第k个类的概率为Pk,则其基尼指数为:

                            

         情况2:

                   对于二分类问题,若样本属于第一个类的概率为P,则概率分布的Gini为:

                            Gini(P)= 2P(1-P)

         情况1的样例:

                   对于给定的样本集合D,Ck是D中属于第k个类的样本子集,N是类的个数,则GIni为:

                            

         在上面2种情况的基础上:

                   若样本集合D依据特征A是否属于某一可能是a被切割成D1和D2两部分,即:

                            D1= { (x, y)∈D | A(x) = a }

                            D2= D - D1

                   那么在特征A的条件下,集合D的Gini为:

                            

                   即:经A = a分隔后集合D的不确定性。

CART生成算法

         描写叙述:

                   输入:

                            训练数据D。停止计算的条件

                   输出:

                            CART决策树

                   解:

                            依据训练数据集,从根节点開始。递归的对每一个节点进行下面操作来构建二叉决策树。

                            1。从节点的训练数据集D计算现有特征对该数据集的基尼指数(Gini)。此时,对每个特征A,对其可能取得每个值a,依据“样本A = a的结果是‘是’或‘否’”将D切割成D1。D2两部分。利用式①计算A = a 时的Gini。

                            2,在全部可能的特征A以及它们全部可能的切分点a中选择Gini最小的特征及其相应的切分点作为最优特征和最优切分点,然后根据最优特征与最优切分点从现节点生成两个子结点,最后将训练数据集根据特征分配到两个子结点中去。

                            3,对两个子结点递归的调用上面两步直至满足停止条件。

                            4。生成CART决策树。

                            PS:算法的停止条件是“节点的样本个数 < 预定阈值”或“样本集合的Gini < 预定阈值(样本基本属于同一类)”或“无很多其它特征”。

 

         样例:

                   对“贷款申请样本数据表”,应用CART算法生成决策树。

                                                                 (贷款申请样本数据表)

ID

年龄

有工作

有自己的房子

信贷情况

类别(是否能贷到款)

1

青年

一般

2

青年

3

青年

4

青年

一般

5

青年

一般

6

中年

一般

7

中年

8

中年

9

中年

很好

10

中年

很好

11

老年

很好

12

老年

13

老年

14

老年

很好

15

老年

一般

                   解:

                            我们先对上表中的各个种类和特征做一个标记:

                                     A1、A2、A3、A4 分别代表:年龄、有无工作、有无房子、信贷情况

                                     1、2、3:表示年龄的青、中、老年

                                     1、2:表示有房、有工作 和 无房、无工作

                                     1、2、3:表示信贷情况很好、好和一般

                            1。求A1的Gini:

                                     

                                     同理:

                                               Gini(D,A1 = 2) = 0.48

                                               Gini(D,A1 = 3) = 0.44

                                     因为Gini(D, A1= 1) 和Gini(D, A1 = 3) 相等且最小。所以A1 = 1和A1 = 3均可作为A1的最优切分点

                            2,而A2和A3仅仅有两个特征,那假设要用A2或A3来分类样本集合的话,那切分点就无疑仅仅有一个了,即A2和A3仅仅有一个切分点

                                     只是尽管A2和A3仅仅有一个切分点。但其Gini还是要计算的,由于要比較全部特征的Gini,从中选一个最小的作为第一次分类的最优特征和最优切分点。

                                     于是:

                                               Gini(D,A2 = 1) = 0.32

                                               Gini(D,A3 = 1) = 0.27

                            3,同理A4的Gini例如以下:

                                     Gini(D,A4 = 1) = 0.36

                                     Gini(D,A4 = 2) = 0.47

                                     Gini(D,A4 = 3) = 0.32

                                     对A4来说,Gini(D, A4= 3) 是A4的最优切分点

                            4, ∵ 在A1,A2。A3。A4中Gini(D, A3 =1) = 0.27最小

                                     ∴ 选择A3为最优特征。A3 = 1为最优切分点。

                                     ∴ 生成根节点A3和切分后的两个子结点(由于对A3来说,全部有房子的都能贷到款。所以对于切分后的两个子结点中的“有房子”这个节点来说已不用在切分(也没法在切分),所以这个节点是叶子节点)

                            5,对还有一个节点(无房子的那个节点)继续使用上面4步在A1,A2,A3中选择最优特征和最优切分点,结果是A2 = 1。且在此计算得知,全部节点均为叶子节点(均被全然分类– 无房子的样本中:有工作的全贷到款,没工作的全贷不到。

PS:这里没使用预定阈值。所以结束条件为:特征全被使用/样本数据被全然分类)

CART剪枝

         有时由于学习到的决策树过于复杂(分的过于细),所以我们须要对决策树进行剪枝,即:通过在底端剪去一些子树。使决策树变小(变简单)。

         剪枝算法有:

                   减少错误剪枝REP(ReducedError Pruning)

                   悲观错误剪枝PEP(Pessimistic ErrorPruning)

                   基本错误剪枝REP(Err-Based Pruning)

                   代价-复杂度剪枝CCP(Cost-ComplexityPruning)

                   最小错误剪枝MEP(Minimum ErrorPruning)

                   最小期望误判成本MECM(MininumExpected Costof Misclassification)

                   最小描写叙述长度MPL(MininumDescription Length)

         这里总结代价-复杂度剪枝CCP(Cost-ComplexityPruning)。

代价-复杂度剪枝

         就像CART生成算法使用基尼指数作为推断标准一样,这个算法使用“误差增益值”作为推断标准。

         于是。又到记公式的时间了,Yeeee….(个鬼啊),对于决策树T的随意内部节点:

                   误差增益值

                   t:决策树T的随意部位节点

                   |NTt|:子树中包括的叶子节点个数。

(注意:是叶子节点)

                   R(t ):节点t的误差代价(假设该结点被剪枝)

                            R(t ) = r( t ) * P( t )

                            r(t ):节点t上的数据占全部数据的比例

                                     eg:某节点的元素中有a个属于目标类,b个不属于,则r( t )= b / (a + b)

                            P(t ):节点t上的数据占全部数据的比例

                                     eg:某节点有x个元素,全部的节点一共同拥有y个元素。则P( t ) = x / y

                   R(Tt ):子树的误差代价,假设该节点不被剪枝。那它等于子树Tt上全部叶子节点的误差代价之和。

         用一个样例说明下上面的公式吧。

                  一个决策树中一共同拥有60个元素,而当中的一个非叶子结点T4。例如以下图所看到的,求其不属于类1的误差增益值

                  

                            PS:上图说明:

                                     节点T7:6个元素属于类1、3个不属于

                                     节点T8:3个元素属于类1、2个不属于

                                     节点T9:2个元素属于类1、0个不属于

                   解:

                            1,求R( t )

                                     对节点T4,由于全部的数据一共60条。所以:

                                               

                            2。求 |NTt|

                                     ∵ 子树T4共同拥有3个叶子节点:T7、T8、T9

                                     ∴ |NTt| = 3

                            3,求R( Tt )

                                     ∵ R( Tt ) = 子树Tt上全部叶子节点的误差代价之和

                                     ∴

                                         

                            4,求g( t )

                                     综上:

                            上面是求一个非叶子节点的过程,面对实际的决策树就是递归求出全部非叶子结点的g( t ) 后找到最小的那个非叶子节点,然后令其左右孩子为NULL。

                            PS:当有多个非叶子结点的g(t ) 同一时候最小时,取|NTt|最大的那个进行剪枝。

         于是,CART剪枝算法例如以下。

CART剪枝算法

         输入:CART算法生成的决策树T0。

         输出:最优决策树Ta。

         解:

                   1,设k = 0,T = T0

                   2,设a = +∞

                   3,自下而上的对各内部节点t计算g(t),然后令a =min(a, g(t) )

                   4,自上而下的訪问内部节点t,若有g(t) = a,则进行剪枝。并对叶子节点t以多数表决法来决定其类,得到树T。

                   5,设k = k + 1。ak = a。Tk = T

                   6,若T不是由根节点单独构成的树,则返回步骤4

                   7,使用交叉验证法再子树序列T0、T1、…、Tn中选最优子树Ta。


#-*-coding:utf-8-*-
# LANG=en_US.UTF-8
# CART 算法
# 文件名称:CART.py
#

import sys
import math
import copy

dict_all = {
        # 1: 青年;2:中年;3:老年
        '_age' : [
                1, 1, 1, 1, 1,
                2, 2, 2, 2, 2,
                3, 3, 3, 3, 3,
            ],

        # 0:无工作;1:有工作
        '_work' : [
                0, 0, 1, 1, 0,
                0, 0, 1, 0, 0,
                0, 0, 1, 1, 0,
            ],

        # 0:无房子;1:有房子
        '_house' : [
                0, 0, 0, 1, 0,
                0, 0, 1, 1, 1,
                1, 1, 0, 0, 0,
            ],

        # 1:信贷情况一般。2:好;3:很好
        '_credit' : [
                1, 2, 2, 1, 1,
                1, 2, 2, 3, 3,
                3, 2, 2, 3, 1,
            ],
    }

# 0:未申请到贷款;1:申请到贷款
_type = [
        0, 0, 1, 1, 0,
        0, 0, 1, 1, 1,
        1, 1, 1, 1, 0,
    ]

# 二叉树结点
class BinaryTreeNode( object ):
    def __init__( self, name=None, data=None, left=None, right=None, father=None ):
        self.name = name
        self.data = data
        self.left = left
        self.right = left
        self.father = father

# 二叉树遍历
class BTree(object):
    def __init__(self,root=0):
        self.root = root

    # 中序遍历
    def inOrder(self,treenode):
        if treenode is None:
            return

        self.inOrder(treenode.left)
        print treenode.name, treenode.data
        self.inOrder(treenode.right)


# 获得种类中中每个特征的个数,以及该特征中_type = 1的个数 和 其它特征中_type = 1的个数
# 输入:字典中的当前种类的字典。列表 _type。待分析种类列表中的元素序号
# 输出字典:{ '特征': [特征的个数, 该特征中_type = 1(能贷到款)的个数, 其它种特征type = 1的个数] }
# eg,对于 _age:
#   由于其青中老年个 5 个,且青年中能带到款的有2个,中年和老年能贷到款的分别为3个和4个,所以输出:
#       {'1': [5, 2, 7], '2': [5, 3, 6], '3': [5, 4, 5]}
def get_value_type_num( _data, _type_list, num_list ):
    value_dict = {}
    tmp_type = ''
    tmp_item = ''

    for num in num_list:
        item = str( _data[num] )
        if tmp_item != item:
            if item in value_dict.keys():
                value_dict[item][0] = value_dict[item][0] + 1
                if _type_list[num] == 1:
                    value_dict[item][1] = value_dict[item][1] + 1
            else:
                if _type_list[num] == 1:
                    value_dict[item] = [1.0, 1.0, 0.0]
                else:
                    value_dict[item] = [1.0, 0.0, 0.0]
                tmp_item = item
        else:
            value_dict[item][0] = value_dict[item][0] + 1
            if _type_list[num] == 1:
                value_dict[item][1] = value_dict[item][1] + 1

    for num1 in xrange( len(value_dict) ):
        for num2 in xrange( len(value_dict) ):
            if num1 == num2: continue
            value_dict[value_dict.keys()[num1]][2] += value_dict[value_dict.keys()[num2]][1]

    return value_dict


# 获得种类中不同特征包括的元素序号
# 如:相应 dict_all 中的 _age,其包括青中老年。若 num_list 为 [0..15]。则输出:
#   {'1': [0, 1, 2, 3, 4], '2': [5, 6, 7, 8, 9], '3': [10, 11, 12, 13, 14]}
def get_value_type_no( data, data_type, num_list ):
    value_dict = {}
    tmp_item = ''

    for num in num_list:
        item = str( data[data_type][num] )
        if tmp_item != item:
            if item in value_dict.keys():
                value_dict[item].append( num )
            else:
                value_dict[item] = [num,]
        else:
            value_dict[item].append( num )

    return value_dict


# 使用 gini 获得最优切分点
def get_cut_point_by_gini( _dict_all, _type_list, num_list, threshold ):
    target_type = ''
    target_feature = ''
    target_gini = 1000000.0

    for data_key in _dict_all:
        value_dict = get_value_type_num( _dict_all[data_key], _type_list, num_list )
        tmp_feature = ''
        gini = 1000000.0
        # 通过计算当前种类的每个特征的 gini 值,来获得该种类中 gini 最小的那个特征
        for value_key in value_dict.keys():
            all_feature_num = len(_dict_all[data_key])
            this_feature_num = value_dict[value_key][0]
            other_feature_num = all_feature_num - this_feature_num
            this_feature_yes_num = value_dict[value_key][1]
            other_feature_yes_num = value_dict[value_key][2]
            # 计算 gini
            tmp_gini = float( '%.2f' % \
                ( \
                    (
                        ( this_feature_num / all_feature_num ) * \
                        2 * \
                        ( this_feature_yes_num / this_feature_num ) * \
                        ( 1 - this_feature_yes_num / this_feature_num ) \
                    ) + \
                    (
                        ( other_feature_num / all_feature_num ) * \
                        2 * \
                        ( other_feature_yes_num / other_feature_num ) * \
                        ( 1 - other_feature_yes_num / other_feature_num ) \
                    ) \
                ) )
            # 获得该种类中 gini 最小的那个特征
            if float(gini) - tmp_gini > 0.0:
                gini = tmp_gini
                tmp_feature = value_key

            if gini < threshold:
                return data_key, tmp_feature, 'over'
        
        # 通过对照所有种类中 gini 最小的特征,来获得 gini 最小的特征的种类, 该种类以及该种类的特征就是切分点
        if float(target_gini) - float(gini) > 0.0:
            target_type = data_key
            target_feature = tmp_feature

    return target_type, target_feature, 'continue'


# CART 算法
def CART( data, type_list, threshold ):
    # 进行分类
    def classify( root, note_name, note_data, note_type ):
        # 将'特征可能值名字'追加到 root.name 中
        # 将[样本序号的列表]合并到 root.data 中
        root.name.append( note_name )
        root.data.extend( note_data )

        # note_type=='exit' 意味着当前的数据所有属于某一类,不用在分类了
        if not data or note_type == 'exit':
            return

        target_type, target_feature, step = get_cut_point_by_gini( data, type_list, note_data, threshold )

        feature_dict = get_value_type_no( data, target_type, note_data )

        # 从样本集合中将该特征删除
        del data[target_type]

        # 准备左子节点和右子节点,节点的 name 和 data 是个空列表
        root.left = BinaryTreeNode( [], [] )
        root.right = BinaryTreeNode( [], [] )

        # 计算“特征字典”中各个集合中是属于“能贷贷款”的多还是“不能贷贷款”的多
        # 假设是前者:
        #   递归调用 classify,形成左子节点
        # 假设是后者:
        #   递归调用 classify。形成右子节点
        for key in feature_dict.keys():
            num_yes = 0; num_no = 0
            for num in feature_dict[key]:
                if type_list[num] == 1:
                    num_yes = num_yes + 1
                elif type_list[num] == 0:
                    num_no = num_no + 1
                else:
                    print 'ERROR: wrong type in _type'
                    exit()

            note_type = 'not_exit'
            if num_yes == 0 or num_no == 0 or step == 'over':
                note_type = 'exit'
            
            if key == target_feature:
                classify( root.left, '%s:%s' % (target_type, key), feature_dict[key], note_type )
            else:
                classify( root.right, '%s:%s' % (target_type, key), feature_dict[key], note_type )
        
        return root


    tmp_list = []
    for num in xrange( len(dict_all[dict_all.keys()[0]]) ):
        tmp_list.append( num )
    return classify( BinaryTreeNode( [], [] ), 'root', tmp_list, 'not_exit' )


class cost_complexity_pruning_parm( object ):
    def __init__( self, sum_num ):
        # 一共同拥有多少个元素
        self.sum_num = sum_num
        # 某个节点的元素数
        self.node_num = 0.0
        # 某节点的叶子节点数量
        self.leaf_node_num = 0.0
        # 某节点的"错误分类"的元素数量
        self.node_data_error_num = 0.0
        # R(Tt)
        self.Rtt = 0.0
        # 节点的误差率增益值 g(t) 的字典,格式是{'节点名字': 节点的误差率增益值}
        self.error_rate_gain_dict = {}

    # 计算 R(Tt)
    # 參数:self, 该节点的"错误分类"的元素数量, 该节点的元素数
    def count_Rtt( self, node_item_num, node_err_num ):
        self.Rtt = self.Rtt + ( (node_err_num/node_item_num) * (node_item_num/self.sum_num) )

    # 制作误差率增益值 g(t) 的字典
    # g(t) = R(t) - R(Tt) / ( |NTt| - 1 )
    # 參数:self, key, 该节点的"错误分类"的元素数量, 该节点的元素数
    def make_error_rate_gain_value_dict( self, key, node_item_num, node_err_num ):
        rt = node_err_num / self.sum_num
        pt = node_item_num / self.sum_num
        Rt = rt * pt
        NTt = self.leaf_node_num
        self.error_rate_gain_dict[key] = float( '%.3f' % float((Rt-self.Rtt)/(NTt-1)) )

    def print_error_rate_gain( self ):
        print self.error_rate_gain_dict

def get_error_rate_gain_dict( dict_all_pruning, type_list, tree, cls ):
    # 对某个节点求其误差率增益值
    def analyze_node( node, node_name, cls ):
        # 假设是叶子节点。则叶子节点数 + 1,并计算 R(Tt)
        if not node.left and not node.right:
            cls.leaf_node_num = cls.leaf_node_num + 1
            dict_key = node.name[0].split(':')[0]
            value_dict = get_value_type_num( dict_all_pruning[dict_key], type_list, node.data )
            dict_key = node.name[0].split(':')[1]
            cls.count_Rtt( value_dict[dict_key][0], value_dict[dict_key][0] - value_dict[dict_key][1] )
            return

        # 兴许遍历
        analyze_node( node.left, None, cls )
        analyze_node( node.right, None, cls )
        # 假设遍历到 back_order 传进来的 node。则计算其“误差率增益值”
        if node.name[0] == node_name:
            dict_key = node.name[0].split(':')[0]
            # 获得 get_value_type_num 返回的字典(里面包括了该节点的元素总数和"正确分类"的元素数)
            value_dict = get_value_type_num( dict_all_pruning[dict_key], type_list, node.data )

            # 计算"错误分类"的元素数
            dict_key = node.name[0].split(':')[1]
            cls.make_error_rate_gain_value_dict( node.name[0], value_dict[dict_key][0], value_dict[dict_key][0] - value_dict[dict_key][1] )
            return cls.leaf_node_num

    # 兴许遍历决策树
    def back_order( node, cls ):
        # 假设是叶子节点。则返回
        if not node.left and not node.right: return

        back_order( node.left, cls )
        back_order( node.right, cls )
        # 假设是根节点,则返回
        if node.name[0] == 'root': return

        cls.leaf_node_num = 0
        # 反之。求该结点的误差率增益值
        analyze_node( node, node.name[0], cls )

    back_order( tree.root, cls )


def cost_complexity_pruning( dict_all_pruning, type_list, tree, cls ):
    # 进行剪枝
    def pruning( node, target_node_name ):
        if not node.left and not node.right: return
        if node.name[0] == target_node_name:
            node.left = None
            node.right = None
            return

        pruning( node.left, target_node_name )
        pruning( node.right, target_node_name )

    # 获得误差率增益值 g(t) 的字典
    get_error_rate_gain_dict( dict_all_pruning, type_list, tree, cls )
    #cls.print_error_rate_gain()

    # 找出误差率增益值最小的节点
    min_error_rate_gain = 10000.0
    min_error_rate_gain_node = ''
    for key in cls.error_rate_gain_dict.keys():
        error_rate_gain = cls.error_rate_gain_dict[key]
        if error_rate_gain < min_error_rate_gain:
            min_error_rate_gain = error_rate_gain
            min_error_rate_gain_node = key

    pruning( tree.root, min_error_rate_gain_node )


# 阈值
# 假设使用 threshold = 0.3。那在使用 house 将样本数据分类后就停止了
# threshold = 0.3
threshold = 0.1
dict_all_cart = copy.deepcopy( dict_all )
root = CART( dict_all_cart, _type, threshold )
bt = BTree( root )
bt.inOrder( bt.root )
print '\n--------------\n'
# 这一步应该使用训练数据
dict_all_pruning = copy.deepcopy( dict_all )
cost_complexity_pruning( dict_all_pruning, _type, bt, cost_complexity_pruning_parm(len(dict_all_pruning[dict_all_pruning.keys()[0]])) )
bt.inOrder( bt.root )

# 剪枝前
#       root
#       /  \
# house:1  house:0
#           /  \
#      work:1  work:0
#
# 剪枝后(由于仅仅有一个非叶子节点"house:0",所以仅仅能剪这个节点了)
#       root
#       /  \
# house:1  house:0
# 当然,这里剪这个不适合。由于剪枝前的决策树既不复杂也全然划分了样本数据,只是这里仅仅是实现剪枝算法,所以不考虑决策树适不适合剪枝。
# 顺便一提,"剪枝前的决策树在未用完种类的情况下全然划分了样本数据"能够作为适不适合剪枝的推断条件之中的一个。


转载于:https://www.cnblogs.com/zhchoutai/p/7085318.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值