Python-(分类)决策树学习及实现-2017/02/27

分类决策树(Decision Tree)

  • 学习

    这两天在machinelearningmastery.com上学习Python实现CART(Classify And Regression Tree),把分类树从头到尾学习实现了一遍,虽然不是什么难事,还是想记录一下,就当增强增强记忆也好。

  • 算法思路

    分类树逻辑上即为一些连环判断的组合,以Binary Tree的结构承载这个流程,以存在于非叶节点的数据的属性+值为判断条件,以存于各叶节点的值为判断结果。下图即为一个简单的决策树逻辑(图片来源:machinelearningmastery.com)。
    Easy example of Decision Tree

  • 实现

    算法实现分一下几个部分:
    1、Gini函数
    2、树内各节点的分割
    3、树的建立
    4、预测结果

    Gini函数:
    Gini指数作为loss function用来衡量分组后数据“纯净性”(原文用的purity)的指标,判断数据正确分类的程度:

    G = sum(Pk*(1-Pk))

    其中Pk为第k类在分组数据某一组中的比例,sum()函数将每一组中每一类的结果都累加起来。

#Gini index calculation for loss fuction
def gini_index(groups,classes_val):
    '''
    groups为拟分组后的数据,其中的元素为list,list中为拟归为一类的数据,格式为[attribute1,attribute2,...,class],前面为属性,最后一列为类别;classes_val中包含所有类别class值的list。
    '''
    gini = 0.0
    for class_val in classes_val:
        for group in groups:
            size = len(group)
            if size == 0:
                continue
            proportion = [row[-1] for row in group].count(class_val)/float(size)
            gini += proportion*(1-proportion)
    return gini     

树内各节点的分割:
要将数据分类,首先要知道根据什么指标进行分类,即对于每一步的判断条件,应当找出最适合分类的属性及该属性下最适合的值——get_split()将来自要分割节点的所有数据的所有属性和所有值进行遍历分割,分别计算各拟分组的Gini指数,取能获得最小Gini指数的分割方式对该节点进行分割;test_split()即为每次遍历是根据给定的属性和值对数据集进行分割。

#Split dataset into groups by specific attribute and value
def test_split(dataset,index,value):
    left, right = [], []
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

#Split dataset into groups for every splitable node 
def get_split(dataset):
    dimen = len(dataset[0])-1
    b_index, b_value, b_gini, b_group = 999, 999, 999, None
    class_values = list(set([row[-1] for row in dataset]))
    for index in range(dimen):
        for row in dataset:
            group = test_split(dataset,index,row[index])
            gini = gini_index(group,class_values)
            if gini < b_gini:
                b_index, b_value, b_gini, b_group = index, row[index], gini, group
    return {'index':b_index, 'value':b_value, 'gini':b_gini, 'groups':b_group}      

树的建立:
在知道如何对每个节点进行合适分割之后,就要开始用递归的方式调用split()函数不断分割节点来建立整棵树。

考虑递归中的基本情况和需要递归的情况:

1、基本情况(节点分割结束,变为叶节点):分割后的节点没有左子节点或右子节点,;当本次分割后树的深度超出最大深度(max_depth,给定);当本次分割后子节点的数据量小于最小分类后数据量(min_size,给定)或子节点已经被完全正确分类(节点内的所有数据为同一类)。

2、需要递归的情况(子节点继续作为下一个父节点调用分割函数)。

#Make a node a terminal 
def to_terminal(group):
    result = [row[-1] for row in group]
    return max(set(result),key=result.count)

#Split the whole tree by iteration
def split(node,max_depth,min_size,depth):
    left, right = node['groups']
    del(node['groups'])
    #check for no left or right
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return
    #check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return
    #process the left
    if len(left) <= min_size or len(set(row[-1] for row in left)) <= 1:  #check for min size and already splited correctly
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left)
        split(node['left'],max_depth,min_size,depth+1)
    #process the right
    if len(right) <= min_size or len(set(row[-1] for row in right)) <= 1:  #check for min size and already splited correctly
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right)
        split(node['right'],max_depth,min_size,depth+1)

#Build a whole decision tree
def build_tree(dataset,max_depth,min_size):
    root = get_split(dataset)
    split(root,max_depth,min_size,1)
    return root

其中的to_terminal()函数实现将该节点变为叶节点,逻辑为以该节点数据中最大比例的该类作为叶节点的值。build_tree()为封装的建树函数,返回树的根节点。

预测结果:
在训练数据建好决策树之后,对测试数据利用决策树进行预测分类,逻辑即为利用存储在各非叶结点中的一系列判断条件进行从根节点到叶节点的预测:

#Predict the results of a set of data by trained decision tree
def predict(node,row):
    if row[node['index']] < node['value']:
        if isinstance(node['left'],dict):
            return predict(node['left'],row)
        else:
            return node['left']
    else:
        if isinstance(node['right'],dict):
            return predict(node['right'],row)
        else:
            return node['right']

最后将包括训练和预测的函数全部封装到一个decision_tree()函数中,实现算法。

#The (Classify) Decision Tree Algorithm
def decision_tree(train_data,test_data,max_depth,min_size):
    tree_root = build_tree(train_data,max_depth,min_size)
    predicted = []
    for row in test_data:
        predicted.append(predict(tree_root,row))
    return predicted
  • 总结

    分类决策树作为比较经典的机器学习算法,逻辑比较简单,代码量主要在于决策树节点的分割,loss function可以用不同选择(Gini指数,熵),通过minimize loss function来实现最优分割,最后递归选取分割节点来建立完整决策树。

学习与代码参考:machinelearningmastery.com

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值