决策树之CART算法

原文转自:点击打开链接


介绍

         CART是在给定输入变量X条件下,输出随机变量Y的条件概率分布的学习方法。

         CART假设决策树是二叉树,内部节点特征取值为“是”或“否”,左分支是取值为“是”的分支,右分支是取值为“否”的分支。这样决策树等价于递归的二分每个特征(即使数据有多个取值,也把数据分成两部分)

         CART算法由以下两步组成:

                   1)决策树生成:基于训练数据集生成决策树,生成的决策树要尽量大。

                   2)决策树剪枝:用验证数据集对已生成的树进行剪枝并选择最优子树,这是用损失函数最小作为剪枝的标准。

 

CART生成

         决策树的生成即:递归构建二叉决策树的过程

                   对回归树:用平方误差最小化准则,进行特征选择,生成二叉树

                   对分类树:用基尼指数最小化准则,进行特征选择,生成二叉树

         PS:基尼指数(Gini)代表了某一集合的不确定性,Gini越大,样本集合的不确定性就越大,这点和熵相似。

         这里总结分类树。

分类树的生成 – 基尼指数

         首先,需要注意,在计算概率分布的基尼指数时需要考虑不同的情况:

         情况1:

                   假设有N个类,样本点属于第k个类的概率为Pk,则其基尼指数为:

                            

         情况2:

                   对于二分类问题,若样本属于第一个类的概率为P,则概率分布的Gini为:

                            Gini(P)= 2P(1-P)

         情况1的例子:

                   对于给定的样本集合D,Ck是D中属于第k个类的样本子集,N是类的个数,则GIni为:

                            

         在上面2种情况的基础上:

                   若样本集合D根据特征A是否属于某一可能是a被分割成D1和D2两部分,即:

                            D1= { (x, y)∈D | A(x) = a }

                            D2= D - D1

                   那么在特征A的条件下,集合D的Gini为:

                            

                   即:经A = a分隔后集合D的不确定性。

CART生成算法

         描述:

                   输入:

                            训练数据D,停止计算的条件

                   输出:

                            CART决策树

                   解:

                            根据训练数据集,从根节点开始,递归的对每个节点进行以下操作来构建二叉决策树。

                            1,从节点的训练数据集D计算现有特征对该数据集的基尼指数(Gini)。此时,对每一个特征A,对其可能取得每一个值a,根据“样本A = a的结果是‘是’或‘否’”将D分割成D1,D2两部分,利用式①计算A = a 时的Gini。

                            2,在所有可能的特征A以及它们所有可能的切分点a中选择Gini最小的特征及其对应的切分点作为最优特征和最优切分点,然后依据最优特征与最优切分点从现节点生成两个子结点,最后将训练数据集依据特征分配到两个子结点中去。

                            3,对两个子结点递归的调用上面两步直至满足停止条件。

                            4,生成CART决策树。

                            PS:算法的停止条件是“节点的样本个数 < 预定阈值”或“样本集合的Gini < 预定阈值(样本基本属于同一类)”或“无更多特征”。

 

         例子:

                   对“贷款申请样本数据表”,应用CART算法生成决策树。

                                                                 (贷款申请样本数据表)

ID

年龄

有工作

有自己的房子

信贷情况

类别(能否贷到款)

1

青年

一般

2

青年

3

青年

4

青年

一般

5

青年

一般

6

中年

一般

7

中年

8

中年

9

中年

非常好

10

中年

非常好

11

老年

非常好

12

老年

13

老年

14

老年

非常好

15

老年

一般

                   解:

                            我们先对上表中的各个种类和特征做一个标记:

                                     A1、A2、A3、A4 分别代表:年龄、有无工作、有无房子、信贷情况

                                     1、2、3:表示年龄的青、中、老年

                                     1、2:表示有房、有工作 和 无房、无工作

                                     1、2、3:表示信贷情况非常好、好和一般

                            1,求A1的Gini:

                                     

                                     同理:

                                               Gini(D,A1 = 2) = 0.48

                                               Gini(D,A1 = 3) = 0.44

                                     由于Gini(D, A1= 1) 和Gini(D, A1 = 3) 相等且最小,所以A1 = 1和A1 = 3均可作为A1的最优切分点

                            2,而A2和A3只有两个特征,那如果要用A2或A3来分类样本集合的话,那切分点就无疑只有一个了,即A2和A3只有一个切分点

                                     不过虽然A2和A3只有一个切分点,但其Gini还是要计算的,因为要比较所有特征的Gini,从中选一个最小的作为第一次分类的最优特征和最优切分点。

                                     于是:

                                               Gini(D,A2 = 1) = 0.32

                                               Gini(D,A3 = 1) = 0.27

                            3,同理A4的Gini如下:

                                     Gini(D,A4 = 1) = 0.36

                                     Gini(D,A4 = 2) = 0.47

                                     Gini(D,A4 = 3) = 0.32

                                     对A4来说,Gini(D, A4= 3) 是A4的最优切分点

                            4, ∵ 在A1,A2,A3,A4中Gini(D, A3 =1) = 0.27最小

                                     ∴ 选择A3为最优特征,A3 = 1为最优切分点。

                                     ∴ 生成根节点A3和切分后的两个子结点(因为对A3来说,所有有房子的都能贷到款,所以对于切分后的两个子结点中的“有房子”这个节点来说已不用在切分(也没法在切分),所以这个节点是叶子节点)

                            5,对另一个节点(无房子的那个节点)继续使用上面4步在A1,A2,A3中选择最优特征和最优切分点,结果是A2 = 1,且在此计算得知,所有节点均为叶子节点(均被完全分类– 无房子的样本中:有工作的全贷到款,没工作的全贷不到。PS:这里没使用预定阈值,所以结束条件为:特征全被使用/样本数据被完全分类)

CART剪枝

         有时因为学习到的决策树过于复杂(分的过于细),所以我们需要对决策树进行剪枝,即:通过在底端剪去一些子树,使决策树变小(变简单)。

         剪枝算法有:

                   降低错误剪枝REP(ReducedError Pruning)

                   悲观错误剪枝PEP(Pessimistic ErrorPruning)

                   基本错误剪枝REP(Err-Based Pruning)

                   代价-复杂度剪枝CCP(Cost-ComplexityPruning)

                   最小错误剪枝MEP(Minimum ErrorPruning)

                   最小期望误判成本MECM(MininumExpected Costof Misclassification)

                   最小描述长度MPL(MininumDescription Length)

         这里总结代价-复杂度剪枝CCP(Cost-ComplexityPruning)。

代价-复杂度剪枝

         就像CART生成算法使用基尼指数作为判断标准一样,这个算法使用“误差增益值”作为判断标准。

         于是,又到记公式的时间了,Yeeee….(个鬼啊),对于决策树T的任意内部节点:

                   误差增益值

                   t:决策树T的任意部位节点

                   |NTt|:子树中包含的叶子节点个数。(注意:是叶子节点)

                   R(t ):节点t的误差代价(如果该结点被剪枝)

                            R(t ) = r( t ) * P( t )

                            r(t ):节点t上的数据占所有数据的比例

                                     eg:某节点的元素中有a个属于目标类,b个不属于,则r( t )= b / (a + b)

                            P(t ):节点t上的数据占所有数据的比例

                                     eg:某节点有x个元素,所有的节点一共有y个元素,则P( t ) = x / y

                   R(Tt ):子树的误差代价,如果该节点不被剪枝,那它等于子树Tt上所有叶子节点的误差代价之和。

         用一个例子说明下上面的公式吧。

                  一个决策树中一共有60个元素,而其中的一个非叶子结点T4,如下图所示,求其不属于类1的误差增益值

                  

                            PS:上图说明:

                                     节点T7:6个元素属于类1、3个不属于

                                     节点T8:3个元素属于类1、2个不属于

                                     节点T9:2个元素属于类1、0个不属于

                   解:

                            1,求R( t )

                                     对节点T4,因为所有的数据一共60条,所以:

                                               

                            2,求 |NTt|

                                     ∵ 子树T4共有3个叶子节点:T7、T8、T9

                                     ∴ |NTt| = 3

                            3,求R( Tt )

                                     ∵ R( Tt ) = 子树Tt上所有叶子节点的误差代价之和

                                     ∴

                                         

                            4,求g( t )

                                     综上:

                            上面是求一个非叶子节点的过程,面对实际的决策树就是递归求出所有非叶子结点的g( t ) 后找到最小的那个非叶子节点,然后令其左右孩子为NULL。

                            PS:当有多个非叶子结点的g(t ) 同时最小时,取|NTt|最大的那个进行剪枝。

         于是,CART剪枝算法如下。

CART剪枝算法

         输入:CART算法生成的决策树T0。

         输出:最优决策树Ta。

         解:

                   1,设k = 0,T = T0

                   2,设a = +∞

                   3,自下而上的对各内部节点t计算g(t),然后令a =min(a, g(t) )

                   4,自上而下的访问内部节点t,若有g(t) = a,则进行剪枝,并对叶子节点t以多数表决法来决定其类,得到树T。

                   5,设k = k + 1,ak = a,Tk = T

                   6,若T不是由根节点单独构成的树,则返回步骤4

                   7,使用交叉验证法再子树序列T0、T1、…、Tn中选最优子树Ta。

#-*-coding:utf-8-*-
# LANG=en_US.UTF-8
# CART 算法
# 文件名:CART.py
#

import sys
import math
import copy

dict_all = {
        # 1: 青年;2:中年;3:老年
        '_age' : [
                1, 1, 1, 1, 1,
                2, 2, 2, 2, 2,
                3, 3, 3, 3, 3,
            ],

        # 0:无工作;1:有工作
        '_work' : [
                0, 0, 1, 1, 0,
                0, 0, 1, 0, 0,
                0, 0, 1, 1, 0,
            ],

        # 0:无房子;1:有房子
        '_house' : [
                0, 0, 0, 1, 0,
                0, 0, 1, 1, 1,
                1, 1, 0, 0, 0,
            ],

        # 1:信贷情况一般;2:好;3:非常好
        '_credit' : [
                1, 2, 2, 1, 1,
                1, 2, 2, 3, 3,
                3, 2, 2, 3, 1,
            ],
    }

# 0:未申请到贷款;1:申请到贷款
_type = [
        0, 0, 1, 1, 0,
        0, 0, 1, 1, 1,
        1, 1, 1, 1, 0,
    ]

# 二叉树结点
class BinaryTreeNode( object ):
    def __init__( self, name=None, data=None, left=None, right=None, father=None ):
        self.name = name
        self.data = data
        self.left = left
        self.right = left
        self.father = father

# 二叉树遍历
class BTree(object):
    def __init__(self,root=0):
        self.root = root

    # 中序遍历
    def inOrder(self,treenode):
        if treenode is None:
            return

        self.inOrder(treenode.left)
        print treenode.name, treenode.data
        self.inOrder(treenode.right)


# 获得种类中中每个特征的个数,以及该特征中_type = 1的个数 和 其他特征中_type = 1的个数
# 输入:字典中的当前种类的字典,列表 _type,待分析种类列表中的元素序号
# 输出字典:{ '特征': [特征的个数, 该特征中_type = 1(能贷到款)的个数, 其他种特征type = 1的个数] }
# eg,对于 _age:
#   因为其青中老年个 5 个,且青年中能带到款的有2个,中年和老年能贷到款的分别为3个和4个,所以输出:
#       {'1': [5, 2, 7], '2': [5, 3, 6], '3': [5, 4, 5]}
def get_value_type_num( _data, _type_list, num_list ):
    value_dict = {}
    tmp_type = ''
    tmp_item = ''

    for num in num_list:
        item = str( _data[num] )
        if tmp_item != item:
            if item in value_dict.keys():
                value_dict[item][0] = value_dict[item][0] + 1
                if _type_list[num] == 1:
                    value_dict[item][1] = value_dict[item][1] + 1
            else:
                if _type_list[num] == 1:
                    value_dict[item] = [1.0, 1.0, 0.0]
                else:
                    value_dict[item] = [1.0, 0.0, 0.0]
                tmp_item = item
        else:
            value_dict[item][0] = value_dict[item][0] + 1
            if _type_list[num] == 1:
                value_dict[item][1] = value_dict[item][1] + 1

    for num1 in xrange( len(value_dict) ):
        for num2 in xrange( len(value_dict) ):
            if num1 == num2: continue
            value_dict[value_dict.keys()[num1]][2] += value_dict[value_dict.keys()[num2]][1]

    return value_dict


# 获得种类中不同特征包含的元素序号
# 如:对应 dict_all 中的 _age,其包含青中老年,若 num_list 为 [0..15],则输出:
#   {'1': [0, 1, 2, 3, 4], '2': [5, 6, 7, 8, 9], '3': [10, 11, 12, 13, 14]}
def get_value_type_no( data, data_type, num_list ):
    value_dict = {}
    tmp_item = ''

    for num in num_list:
        item = str( data[data_type][num] )
        if tmp_item != item:
            if item in value_dict.keys():
                value_dict[item].append( num )
            else:
                value_dict[item] = [num,]
        else:
            value_dict[item].append( num )

    return value_dict


# 使用 gini 获得最优切分点
def get_cut_point_by_gini( _dict_all, _type_list, num_list, threshold ):
    target_type = ''
    target_feature = ''
    target_gini = 1000000.0

    for data_key in _dict_all:
        value_dict = get_value_type_num( _dict_all[data_key], _type_list, num_list )
        tmp_feature = ''
        gini = 1000000.0
        # 通过计算当前种类的每一个特征的 gini 值,来获得该种类中 gini 最小的那个特征
        for value_key in value_dict.keys():
            all_feature_num = len(_dict_all[data_key])
            this_feature_num = value_dict[value_key][0]
            other_feature_num = all_feature_num - this_feature_num
            this_feature_yes_num = value_dict[value_key][1]
            other_feature_yes_num = value_dict[value_key][2]
            # 计算 gini
            tmp_gini = float( '%.2f' % \
                ( \
                    (
                        ( this_feature_num / all_feature_num ) * \
                        2 * \
                        ( this_feature_yes_num / this_feature_num ) * \
                        ( 1 - this_feature_yes_num / this_feature_num ) \
                    ) + \
                    (
                        ( other_feature_num / all_feature_num ) * \
                        2 * \
                        ( other_feature_yes_num / other_feature_num ) * \
                        ( 1 - other_feature_yes_num / other_feature_num ) \
                    ) \
                ) )
            # 获得该种类中 gini 最小的那个特征
            if float(gini) - tmp_gini > 0.0:
                gini = tmp_gini
                tmp_feature = value_key

            if gini < threshold:
                return data_key, tmp_feature, 'over'
        
        # 通过对比所有种类中 gini 最小的特征,来获得 gini 最小的特征的种类, 该种类以及该种类的特征就是切分点
        if float(target_gini) - float(gini) > 0.0:
            target_type = data_key
            target_feature = tmp_feature

    return target_type, target_feature, 'continue'


# CART 算法
def CART( data, type_list, threshold ):
    # 进行分类
    def classify( root, note_name, note_data, note_type ):
        # 将'特征可能值名字'追加到 root.name 中
        # 将[样本序号的列表]合并到 root.data 中
        root.name.append( note_name )
        root.data.extend( note_data )

        # note_type=='exit' 意味着当前的数据全部属于某一类,不用在分类了
        if not data or note_type == 'exit':
            return

        target_type, target_feature, step = get_cut_point_by_gini( data, type_list, note_data, threshold )

        feature_dict = get_value_type_no( data, target_type, note_data )

        # 从样本集合中将该特征删除
        del data[target_type]

        # 准备左子节点和右子节点,节点的 name 和 data 是个空列表
        root.left = BinaryTreeNode( [], [] )
        root.right = BinaryTreeNode( [], [] )

        # 计算“特征字典”中各个集合中是属于“能贷贷款”的多还是“不能贷贷款”的多
        # 如果是前者:
        #   递归调用 classify,形成左子节点
        # 如果是后者:
        #   递归调用 classify,形成右子节点
        for key in feature_dict.keys():
            num_yes = 0; num_no = 0
            for num in feature_dict[key]:
                if type_list[num] == 1:
                    num_yes = num_yes + 1
                elif type_list[num] == 0:
                    num_no = num_no + 1
                else:
                    print 'ERROR: wrong type in _type'
                    exit()

            note_type = 'not_exit'
            if num_yes == 0 or num_no == 0 or step == 'over':
                note_type = 'exit'
            
            if key == target_feature:
                classify( root.left, '%s:%s' % (target_type, key), feature_dict[key], note_type )
            else:
                classify( root.right, '%s:%s' % (target_type, key), feature_dict[key], note_type )
        
        return root


    tmp_list = []
    for num in xrange( len(dict_all[dict_all.keys()[0]]) ):
        tmp_list.append( num )
    return classify( BinaryTreeNode( [], [] ), 'root', tmp_list, 'not_exit' )


class cost_complexity_pruning_parm( object ):
    def __init__( self, sum_num ):
        # 一共有多少个元素
        self.sum_num = sum_num
        # 某个节点的元素数
        self.node_num = 0.0
        # 某节点的叶子节点数量
        self.leaf_node_num = 0.0
        # 某节点的"错误分类"的元素数量
        self.node_data_error_num = 0.0
        # R(Tt)
        self.Rtt = 0.0
        # 节点的误差率增益值 g(t) 的字典,格式是{'节点名字': 节点的误差率增益值}
        self.error_rate_gain_dict = {}

    # 计算 R(Tt)
    # 参数:self, 该节点的"错误分类"的元素数量, 该节点的元素数
    def count_Rtt( self, node_item_num, node_err_num ):
        self.Rtt = self.Rtt + ( (node_err_num/node_item_num) * (node_item_num/self.sum_num) )

    # 制作误差率增益值 g(t) 的字典
    # g(t) = R(t) - R(Tt) / ( |NTt| - 1 )
    # 参数:self, key, 该节点的"错误分类"的元素数量, 该节点的元素数
    def make_error_rate_gain_value_dict( self, key, node_item_num, node_err_num ):
        rt = node_err_num / self.sum_num
        pt = node_item_num / self.sum_num
        Rt = rt * pt
        NTt = self.leaf_node_num
        self.error_rate_gain_dict[key] = float( '%.3f' % float((Rt-self.Rtt)/(NTt-1)) )

    def print_error_rate_gain( self ):
        print self.error_rate_gain_dict

def get_error_rate_gain_dict( dict_all_pruning, type_list, tree, cls ):
    # 对某个节点求其误差率增益值
    def analyze_node( node, node_name, cls ):
        # 如果是叶子节点,则叶子节点数 + 1,并计算 R(Tt)
        if not node.left and not node.right:
            cls.leaf_node_num = cls.leaf_node_num + 1
            dict_key = node.name[0].split(':')[0]
            value_dict = get_value_type_num( dict_all_pruning[dict_key], type_list, node.data )
            dict_key = node.name[0].split(':')[1]
            cls.count_Rtt( value_dict[dict_key][0], value_dict[dict_key][0] - value_dict[dict_key][1] )
            return

        # 后续遍历
        analyze_node( node.left, None, cls )
        analyze_node( node.right, None, cls )
        # 如果遍历到 back_order 传进来的 node,则计算其“误差率增益值”
        if node.name[0] == node_name:
            dict_key = node.name[0].split(':')[0]
            # 获得 get_value_type_num 返回的字典(里面包含了该节点的元素总数和"正确分类"的元素数)
            value_dict = get_value_type_num( dict_all_pruning[dict_key], type_list, node.data )

            # 计算"错误分类"的元素数
            dict_key = node.name[0].split(':')[1]
            cls.make_error_rate_gain_value_dict( node.name[0], value_dict[dict_key][0], value_dict[dict_key][0] - value_dict[dict_key][1] )
            return cls.leaf_node_num

    # 后续遍历决策树
    def back_order( node, cls ):
        # 如果是叶子节点,则返回
        if not node.left and not node.right: return

        back_order( node.left, cls )
        back_order( node.right, cls )
        # 如果是根节点,则返回
        if node.name[0] == 'root': return

        cls.leaf_node_num = 0
        # 反之,求该结点的误差率增益值
        analyze_node( node, node.name[0], cls )

    back_order( tree.root, cls )


def cost_complexity_pruning( dict_all_pruning, type_list, tree, cls ):
    # 进行剪枝
    def pruning( node, target_node_name ):
        if not node.left and not node.right: return
        if node.name[0] == target_node_name:
            node.left = None
            node.right = None
            return

        pruning( node.left, target_node_name )
        pruning( node.right, target_node_name )

    # 获得误差率增益值 g(t) 的字典
    get_error_rate_gain_dict( dict_all_pruning, type_list, tree, cls )
    #cls.print_error_rate_gain()

    # 找出误差率增益值最小的节点
    min_error_rate_gain = 10000.0
    min_error_rate_gain_node = ''
    for key in cls.error_rate_gain_dict.keys():
        error_rate_gain = cls.error_rate_gain_dict[key]
        if error_rate_gain < min_error_rate_gain:
            min_error_rate_gain = error_rate_gain
            min_error_rate_gain_node = key

    pruning( tree.root, min_error_rate_gain_node )


# 阈值
# 如果使用 threshold = 0.3,那在使用 house 将样本数据分类后就停止了
# threshold = 0.3
threshold = 0.1
dict_all_cart = copy.deepcopy( dict_all )
root = CART( dict_all_cart, _type, threshold )
bt = BTree( root )
bt.inOrder( bt.root )
print '\n--------------\n'
# 这一步应该使用训练数据
dict_all_pruning = copy.deepcopy( dict_all )
cost_complexity_pruning( dict_all_pruning, _type, bt, cost_complexity_pruning_parm(len(dict_all_pruning[dict_all_pruning.keys()[0]])) )
bt.inOrder( bt.root )

# 剪枝前
#       root
#       /  \
# house:1  house:0
#           /  \
#      work:1  work:0
#
# 剪枝后(因为只有一个非叶子节点"house:0",所以只能剪这个节点了)
#       root
#       /  \
# house:1  house:0
# 当然,这里剪这个不适合,因为剪枝前的决策树既不复杂也完全划分了样本数据,不过这里仅仅是实现剪枝算法,所以不考虑决策树适不适合剪枝。
# 顺便一提,"剪枝前的决策树在未用完种类的情况下完全划分了样本数据"可以作为适不适合剪枝的判断条件之一。


  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值