决策树算法

1.算法介绍

分类算法有很多种,比如贝叶斯算法、神经网络、支持向量机、决策树算法。说句老实话,神经网络、支持向量机算法我目前也不清楚,但决定在这段时间研究一下。

言归正传,本篇文章想要介绍的算法是决策树算法。所谓决策树算法,即:

对于一个数据集D,它有n个元素,每个元素有m个属性。通过决策树,可以通过这些元素的属性,得到这个元素所属的类别。决策树算法包括两个过程,构造树的过程,已经利用决策树进行分类的过程。

还是按老规矩,举个例子:

比如,已知:一个好的程序员,他应该热爱学习新的技术,写的代码质量高,会2种以上的编程语言。

根据这个条件,我们就可以新建一棵决策树:


这棵决策树比较简单,它的每种属性判定只有2个分支。而真正的决策树,某个属性的判定可能有n个分支(n>=2)。

不管怎么说,假设有一个人,他叫小明,写的代码很高,但是他不喜欢学习新的技术,对于现在各式各样的大数据解决方案、JS库的出现惶恐不已,那么根据这棵决策树判定,他不算一个好的程序员。

2.算法核心

从上面的描述可以看出,决策树算法的难点在于决策树的构建。一旦决策树构建好了,用它进行判定就会非常容易了。

但是用于构建决策树的数据集D含有的元素,含有n种属性,如何选择哪个属性作为根节点,选择哪个属性作为第二层分支的根节点呢?

这就牵扯到香农熵。

有这样一个理论:一个数据集发生变化时,如果它的熵也发生了变化,则称为信息增益。如果信息增益变化越大,则这个变化越好。

我们在构建决策树时,可以计算用哪个属性进行划分时,导致的熵变化大,这样就可以确定出树每一层的属性。

确定某一层的属性后,该属性的分支根节点其实就是这个属性对应的不同的值;分支根节点下方的子树,则是特定数据除去该属性后剩下的属性,递归用相同方法进行构建树。可以看第三小节,建立决策树就是一个递归的过程。

建立完树之后,对于测试数据的判定,就是根据测试数据,遍历树,直至找到叶子节点(类别),从而得到分类结果。

3.算法源码

源码思路参考了《机器学习实战》

#coding:utf-8
import math
import types
#数据集
samples = [
    [1,1,'yes'],
    [1,1,'yes'],
    [1,0,'no'],
    [0,1,'no'],
    [0,1,'no']
]
#计算某个数据集的香农熵
def cal_shannon(dataset):
    classes = [data[-1] for data in dataset]
    distinct_classes = set(classes)
    record = dict()
    for curclass in classes:
        if curclass not in record.keys():
            record[curclass] = 1
        else:
            record[curclass] += 1
    result = 0.0
    for key in record:
        prob = float(record[key])/len(classes)
        result -= prob*math.log(prob,2)
    return result
#过滤数据集中,某个index下标值等于value的集合,得到过滤后的集合
def filter_set(dataset,index,value):
    result = list()
    for data in dataset:
        if data[index] == value:
            temp = data[0:index]
            temp.extend(data[index+1:])
            result.append(temp)
    return result
#计算数据集按照哪个index进行划分最好
def get_best_split(dataset):
    best_index = -1
    base_shannon = cal_shannon(dataset)
    best_gain = 0.0
    for i in range(0,len(dataset[0])-1):
        dist_val = set([data[i] for data in samples])
        cur_shannon = 0.0
        for cur_val in dist_val:
            subset = filter_set(dataset,i,cur_val)
            prob = float(len(subset))/len(dataset)
            cur_shannon += prob*cal_shannon(subset)
        cur_gain = cur_shannon - base_shannon
        if abs(cur_gain) > best_gain:
            best_gain = abs(cur_gain)
            best_index = i
    return best_index
#获得集合中数量最多的classes
def get_max_class(list):
    dic = dict()
    for key in list:
        if key not in dic.keys():
            dic[key] = 1
        else:
            dic[key] += 1
    rtn = sorted(dic.iteritems(),key=lambda dic:dic[1],reverse=True)
    return rtn[0][0] 
#构建决策树,递归构建
def create_tree(dataset,keys):
    best_index = get_best_split(dataset)
    #如果dataset的类都相同,则返回所属的类(叶子节点)
    classes = [data[-1] for data in dataset]
    if classes.count(classes[0]) == len(classes):
        return classes[0]
    #如果所有属性都被遍历完了,则返回
    if len(dataset[0]) == 1:
        return get_max_class(classes)
    #此index对应的dataset的值分为几类
    best_key = keys[best_index]
    vals = set([data[best_index] for data in dataset])
    node = {best_key:{}}
    for val in vals:
        newkeys = keys[0:best_index]
        newkeys.extend(keys[best_index+1:])
        subset = filter_set(dataset,best_index,val)
        node[best_key][val] = create_tree(subset,newkeys)
    return node
#决策树进行判断
def judge(tree,labels,vals):
    # 因为需要判断的元素的labels的顺序不一定与决策树相同
    # 所以需要判断出决策树的当前分支的根节点,它是需要判断的元素的哪个属性
    node_key = tree.keys()[0]
    node_child = tree[node_key]
    index = labels.index(node_key)
    for key in node_child.keys():
        if key == vals[index]:
            node_child_child = node_child[key]
            if type(node_child_child) == types.StringType:
                curclass = node_child_child
            else:
                curclass = judge(node_child_child,labels,vals)   
    return curclass
if __name__ == '__main__':
    #训练过程,根据samples构建决策树
    keys = ['attr1','attr2']
    root = create_tree(samples,keys)
    print 'The decision tree is :',root

    #使用过程,利用决策树对于test_data判断所属的类别
    test_data1 = [0,0]
    print test_data1,'Judge result is :',judge(root,keys,test_data1)
    test_data2 = [1,1]
    print test_data2,'Judge result is :',judge(root,keys,test_data2)

算法输出:

The decision tree is : {'attr1': {0: 'no', 1: {'attr2': {0: 'no', 1: 'yes'}}}}
[0, 0] Judge result is : no
[1, 1] Judge result is : yes


  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值