# -*- coding: utf-8 -*-
# 周志华老师的书上西瓜分类问题。
# 数据为p76页表4.1,算法逻辑完全按照p74写的
# author: Xin Chen
import pandas as pd
import numpy as np
def getdata(path):
    """Load the training set from a gbk-encoded CSV file.

    The first CSV column is used as the row index; the remaining columns
    are the feature columns, with the class in the final 'label' column.
    """
    return pd.read_csv(path, index_col=0, encoding='gbk')
def entropy(dataset):
    """Return the information entropy of the 'label' column of *dataset*.

    A single-row subset arrives as a pandas Series after label-based
    selection; one sample is perfectly pure, so its entropy is 0.
    """
    # BUGFIX(idiom): use isinstance() instead of `type(x) == pd.Series`,
    # which fails for Series subclasses and is non-idiomatic.
    if isinstance(dataset, pd.Series):
        return 0
    labels = dataset['label']
    counts = labels.groupby(labels).count()
    pr = counts / counts.sum()
    log2pr = np.log2(pr)
    # Treat 0*log2(0) as 0 (pr is never 0 here since counts come from
    # observed labels, but keep the guard for safety).
    log2pr[~np.isfinite(log2pr)] = 0
    return -(pr * log2pr).sum()
def cal_gain(dataset, a):
    """Return the information gain of splitting *dataset* on attribute *a*
    (eq. 4.2, p75): Ent(D) - sum_v |D_v|/|D| * Ent(D_v).
    """
    total_entropy = entropy(dataset)
    values = dataset[a].unique()
    col = dataset[a]
    value_counts = col.groupby(col).count()
    indexed = dataset.set_index(a)
    # BUGFIX: DataFrame.ix was deprecated in pandas 0.20 and removed in
    # 1.0; .loc performs the same label-based row selection here.
    branch_entropy = pd.Series({v: entropy(indexed.loc[v]) for v in values})
    # Series multiplication aligns by label, so each branch entropy is
    # weighted by the number of samples taking that attribute value.
    weighted = (branch_entropy * value_counts).sum() / value_counts.sum()
    return total_entropy - weighted
def get_optimala(dataset):
    """Return the splitting attribute with maximal information gain.

    Ties (and the case where no attribute beats a gain of 0.0) fall back
    to the first attribute, matching a strict-greater-than comparison.
    """
    candidates = dataset.columns[:-1]
    best_attr, best_gain = candidates[0], 0.0
    for attr in candidates:
        gain = cal_gain(dataset, attr)
        if gain > best_gain:
            best_attr, best_gain = attr, gain
    return best_attr
def mostlabel(labels):
    """Return the most frequent value in *labels*.

    Ties go to the smallest label, since groupby sorts the index and
    idxmax returns the first index attaining the maximum count.
    """
    freq = labels.groupby(labels).count()
    return freq.idxmax()
def generateTree(dataset):
    """Recursively build an ID3 decision tree (algorithm on p74).

    Returns a leaf label, or a nested dict {attribute: {value: subtree}}.
    Relies on the module-level global `attrivalues`, set up in main(),
    which maps each attribute to its full value set in the original
    training data so that empty branches can still be generated.
    """
    global initialdataset, attrivalues
    labels = dataset['label']
    attributes = dataset.columns[:-1]
    if len(labels) == 0:
        return
    # Case 1: every sample belongs to one class -> leaf with that class.
    if len(labels.unique()) == 1:
        return labels.values[0]
    # Case 2: no attributes left, or all samples identical on the
    # remaining attributes -> leaf with the majority class.
    x = dataset.drop('label', axis=1)
    distinct = (~x.duplicated()).sum()
    # BUGFIX: the original compared against 0, which can never hold for a
    # non-empty frame (the first row is never flagged as a duplicate), so
    # the "all samples identical" case was unreachable; it must be == 1.
    if (len(attributes) == 0) or (distinct == 1):
        return mostlabel(labels)
    # Otherwise split on the attribute with maximal information gain.
    optimala = get_optimala(dataset)
    avs = attrivalues[optimala]
    tree = {optimala: {}}
    for ai in avs:
        newdataset = dataset[dataset[optimala] == ai]
        newdataset = newdataset.drop(optimala, axis=1)
        if len(newdataset) == 0:
            # No sample takes this value: predict the parent's majority class.
            tree[optimala][ai] = mostlabel(labels)
        else:
            tree[optimala][ai] = generateTree(newdataset)
    return tree
def store_tree(tree, filename):
    """Serialize the decision tree to *filename* with pickle."""
    import pickle
    # BUGFIX: pickle output is binary; text mode ('w') breaks on Python 3
    # and corrupts data on Windows. Use 'wb' and a context manager so the
    # file is closed even if dump() raises.
    with open(filename, 'wb') as writer:
        pickle.dump(tree, writer)
def read_tree(filename):
    """Deserialize and return a decision tree pickled by store_tree()."""
    import pickle
    # BUGFIX: pickles must be read in binary mode ('rb'), not text mode
    # ('rU'); the original also leaked the file handle — use `with`.
    # NOTE(security): only unpickle files this program wrote itself;
    # pickle.load on untrusted data can execute arbitrary code.
    with open(filename, 'rb') as reader:
        return pickle.load(reader)
def main():
    """Load the watermelon data, build the ID3 tree and print it."""
    path = "watermelon.csv"
    global initialdataset, attrivalues
    initialdataset = getdata(path)
    # Record the full value set of every attribute up front, so branches
    # can still be generated for values absent from a recursive subset.
    attrivalues = {}
    attributes = initialdataset.columns[:-1]
    for a in attributes:
        attrivalues[a] = initialdataset[a].unique()
    tree = generateTree(initialdataset)
    # BUGFIX: the Python-2 print statement is a SyntaxError on Python 3;
    # this %-formatted single-argument form behaves the same on both.
    print("decision tree: %s" % (tree,))
    # store_tree(tree, 'tree')
    # print(read_tree('tree'))
if __name__ == "__main__":
    main()
# 数据:watermelon.csv
'''
,color,root,sound,stripes,navel,touch,label
1,青绿,蜷缩,浊响,清晰,凹陷,硬滑,1
2,乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,1
3,乌黑,蜷缩,浊响,清晰,凹陷,硬滑,1
4,青绿,蜷缩,沉闷,清晰,凹陷,硬滑,1
5,浅白,蜷缩,浊响,清晰,凹陷,硬滑,1
6,青绿,稍蜷,浊响,清晰,稍凹,软粘,1
7,乌黑,稍蜷,浊响,稍糊,稍凹,软粘,1
8,乌黑,稍蜷,浊响,清晰,稍凹,硬滑,1
9,乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,0
10,青绿,硬挺,清脆,清晰,平坦,软粘,0
11,浅白,硬挺,清脆,模糊,平坦,硬滑,0
12,浅白,蜷缩,浊响,模糊,平坦,软粘,0
13,青绿,稍蜷,浊响,稍糊,凹陷,硬滑,0
14,浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,0
15,乌黑,稍蜷,浊响,清晰,稍凹,软粘,0
16,浅白,蜷缩,浊响,模糊,平坦,硬滑,0
17,青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,0
'''
# 最后的运行结果:
# decision tree: {u'stripes': {u'\u6a21\u7cca': 0, u'\u6e05\u6670': {u'root': {u'\u7a0d\u8737':
# {u'color': {u'\u6d45\u767d': 0, u'\u9752\u7eff': 1, u'\u4e4c\u9ed1': {u'touch': {u'\u8f6f\u7c98': 0, u'\u786c\u6ed1': 1}}}},
# u'\u786c\u633a': 0, u'\u8737\u7f29': 1}}, u'\u7a0d\u7cca': {u'touch': {u'\u8f6f\u7c98': 1, u'\u786c\u6ed1': 0}}}}
# 重新编码之后:
# decision tree: {u'stripes': {u'模糊': 0, u'清晰': {u'root': {u'稍蜷':
# {u'color': {u'浅白': 0, u'青绿': 1, u'乌黑': {u'touch': {u'软粘': 0, u'硬滑': 1}}}},
# u'硬挺': 0, u'蜷缩': 1}}, u'稍糊': {u'touch': {u'软粘': 1, u'硬滑': 0}}}}
# 跟p78的结果完全一样,当然这里的采用的决策是最简单的,例子也很简单。加油!!