机器学习初学代码(四) 决策树1

# 周志华老师的书上西瓜分类问题。
# 数据为p76页表4.1,算法逻辑完全按照p74写的

# -*- coding: utf-8 -*-
# author: Xin Chen

import pandas as pd
import numpy as np


def getdata(path, encoding='gbk'):
    """Load the training set from a CSV file.

    The first CSV column is used as the row index. The rest of this
    script expects the remaining columns to be the attribute columns
    followed by a final column named ``label``.

    Args:
        path: Path of the CSV file to read.
        encoding: Text encoding of the file. Defaults to ``'gbk'``, the
            value that used to be hard-coded here.

    Returns:
        A ``pandas.DataFrame`` with the file contents.
    """
    return pd.read_csv(path, index_col=0, encoding=encoding)

def entropy(dataset):
    """Return the information entropy (base 2) of a data set's labels.

    Args:
        dataset: A DataFrame with a ``label`` column. A ``pandas.Series``
            (which is what selecting a single row produces) is treated as
            one sample.

    Returns:
        The entropy of the ``label`` distribution; 0 for a Series.
    """
    # A single sample is always pure.
    if isinstance(dataset, pd.Series):
        return 0
    labels = dataset['label']
    counts = labels.groupby(labels).count()
    # Drop classes with no samples so log2 never sees 0; they contribute
    # nothing to the entropy anyway.
    counts = counts[counts > 0]
    probs = counts / counts.sum()
    return -(probs * np.log2(probs)).sum()

def cal_gain(dataset, a):
    """Return the information gain of splitting ``dataset`` on attribute ``a``.

    The gain is the entropy of the whole data set minus the size-weighted
    average entropy of the subsets obtained for each value of ``a``.

    Args:
        dataset: A DataFrame containing column ``a`` and a ``label`` column.
        a: Name of the attribute column to split on.

    Returns:
        The information gain as a float.
    """
    total_entropy = entropy(dataset)
    column = dataset[a]
    # Number of samples taking each value of the attribute.
    sizes = column.groupby(column).count()
    subset_entropy = {}
    for value in column.unique():
        # Boolean-mask selection always yields a DataFrame, even for a
        # single matching row (the removed ``.ix`` indexer returned a
        # Series in that case).
        subset_entropy[value] = entropy(dataset[column == value])
    # Both Series are indexed by attribute value, so the product aligns
    # each subset's entropy with its sample count.
    weighted = (pd.Series(subset_entropy) * sizes).sum() / sizes.sum()
    return total_entropy - weighted

def get_optimala(dataset):
    """Pick the attribute with the largest information gain.

    Every column except the last one is treated as a candidate attribute.
    The first candidate is returned unless a later one has a gain that is
    strictly greater than both 0 and every gain seen so far.

    Args:
        dataset: A DataFrame whose last column is the label column.

    Returns:
        The name of the selected attribute column.
    """
    candidates = dataset.columns[:-1]
    best_name = candidates[0]
    best_gain = 0.0
    for name in candidates:
        current = cal_gain(dataset, name)
        if not current > best_gain:
            continue
        best_name, best_gain = name, current
    return best_name

def mostlabel(labels):
    """Return the most frequent value in ``labels``.

    When several values share the highest count, the one that comes first
    in the grouped (sorted) index is returned.
    """
    tally = labels.groupby(labels).count()
    is_top = tally == tally.max()
    winners = tally.index[is_top.values]
    return winners[0]

def generateTree(dataset):
    """Recursively build a decision tree from ``dataset``.

    Reads the module-level ``attrivalues`` mapping (attribute name -> all
    values seen in the initial data set), which ``main`` fills in before
    the first call.

    Args:
        dataset: A DataFrame whose last column is ``label`` and whose other
            columns are the attributes still available for splitting.

    Returns:
        ``None`` for an empty data set, a label value for a leaf, or a
        nested dict ``{attribute: {value: subtree_or_label}}``.
    """
    labels = dataset['label']
    attributes = dataset.columns[:-1]

    if len(labels) == 0:
        return
    # All samples belong to one class: return that class as a leaf.
    if len(labels.unique()) == 1:
        return labels.values[0]
    # No attribute left to split on: return the majority class.
    if len(attributes) == 0:
        return mostlabel(labels)
    # All samples agree on every remaining attribute, so no split can
    # separate them: return the majority class. The set is non-empty here,
    # so "all identical" means exactly one distinct row.
    features = dataset.drop('label', axis=1)
    if len(features.drop_duplicates()) == 1:
        return mostlabel(labels)

    optimala = get_optimala(dataset)
    tree = {optimala: {}}
    # Iterate over every value the attribute takes in the initial data so
    # that values missing from this subset still get a branch.
    for ai in attrivalues[optimala]:
        subset = dataset[dataset[optimala] == ai].drop(optimala, axis=1)
        if len(subset) == 0:
            # No sample takes this value: use the parent's majority class.
            tree[optimala][ai] = mostlabel(labels)
        else:
            tree[optimala][ai] = generateTree(subset)
    return tree

def store_tree(tree, filename):
    """Serialize a decision tree to ``filename`` using pickle.

    Args:
        tree: The tree object to store (any picklable object).
        filename: Path of the file to write; it is overwritten.
    """
    import pickle
    # pickle produces bytes, so the file must be opened in binary mode;
    # the with-block closes it even if dumping fails.
    with open(filename, 'wb') as writer:
        pickle.dump(tree, writer)

def read_tree(filename):
    """Load and return a decision tree previously pickled to ``filename``.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    import pickle
    # NOTE: unpickling can execute arbitrary code; only load files that
    # come from a trusted source.
    # Binary mode is required for pickle data, and the with-block makes
    # sure the handle is closed.
    with open(filename, 'rb') as reader:
        return pickle.load(reader)

def main():
    """Build a decision tree from ``watermelon.csv`` and print it.

    Sets the module-level globals ``initialdataset`` (the loaded data) and
    ``attrivalues`` (attribute name -> array of its distinct values), which
    ``generateTree`` reads while building the tree.
    """
    global initialdataset, attrivalues
    path = "watermelon.csv"
    initialdataset = getdata(path)
    attrivalues = {}
    # Every column except the last (the label) is an attribute.
    attributes = initialdataset.columns[:-1]
    for a in attributes:
        attrivalues[a] = initialdataset[a].unique()
    tree = generateTree(initialdataset)
    # Single-argument form works as both the Python 2 print statement and
    # the Python 3 print function; the two spaces reproduce the separator
    # the old ``print "decision tree: ", tree`` statement emitted.
    print("decision tree:  %s" % (tree,))
    # The tree can be saved with store_tree(tree, 'tree') and loaded again
    # with read_tree('tree').

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


# 数据:watermelon.csv
'''
,color,root,sound,stripes,navel,touch,label
1,青绿,蜷缩,浊响,清晰,凹陷,硬滑,1
2,乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,1
3,乌黑,蜷缩,浊响,清晰,凹陷,硬滑,1
4,青绿,蜷缩,沉闷,清晰,凹陷,硬滑,1
5,浅白,蜷缩,浊响,清晰,凹陷,硬滑,1
6,青绿,稍蜷,浊响,清晰,稍凹,软粘,1
7,乌黑,稍蜷,浊响,稍糊,稍凹,软粘,1
8,乌黑,稍蜷,浊响,清晰,稍凹,硬滑,1
9,乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,0
10,青绿,硬挺,清脆,清晰,平坦,软粘,0
11,浅白,硬挺,清脆,模糊,平坦,硬滑,0
12,浅白,蜷缩,浊响,模糊,平坦,软粘,0
13,青绿,稍蜷,浊响,稍糊,凹陷,硬滑,0
14,浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,0
15,乌黑,稍蜷,浊响,清晰,稍凹,软粘,0
16,浅白,蜷缩,浊响,模糊,平坦,硬滑,0
17,青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,0
'''

# 最后的运行结果:
decision tree:  {u'stripes': {u'\u6a21\u7cca': 0, u'\u6e05\u6670': {u'root': {u'\u7a0d\u8737':
{u'color': {u'\u6d45\u767d': 1, u'\u9752\u7eff': 1, u'\u4e4c\u9ed1': {u'touch': {u'\u8f6f\u7c98': 0, u'\u786c\u6ed1': 1}}}},
u'\u786c\u633a': 0, u'\u8737\u7f29': 1}}, u'\u7a0d\u7cca': {u'touch': {u'\u8f6f\u7c98': 1, u'\u786c\u6ed1': 0}}}}

# 重新编码之后:
decision tree:  {u'stripes': {u'模糊': 0, u'清晰': {u'root': {u'稍蜷':
{u'color': {u'浅白': 1, u'青绿': 1, u'乌黑': {u'touch': {u'软粘': 0, u'硬滑': 1}}}},
u'硬挺': 0, u'蜷缩': 1}}, u'稍糊': {u'touch': {u'软粘': 1, u'硬滑': 0}}}}

# 跟p78的结果完全一样。当然,这里采用的决策树算法是最简单的,例子也很简单。加油!!

 

转载于:https://my.oschina.net/u/3590872/blog/1377364

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值