Adapted from *Machine Learning in Action* (机器学习实战).
```python
from collections import Counter
from math import log


def createDataset():
    # toy dataset from Machine Learning in Action:
    # features are 'no surfacing' and 'flippers', the last column is the class label
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def calcShannonEnt(dataSet):
    # Shannon entropy of the class labels: H = -sum(p * log2(p))
    label_count = Counter(data[-1] for data in dataSet)
    probs = [count / len(dataSet) for count in label_count.values()]
    return sum(-p * log(p, 2) for p in probs)


def splitDataSet(dataSet, index, value):
    # keep the rows whose feature at `index` equals `value`,
    # and drop that feature column from the returned rows
    return [data[:index] + data[index + 1:] for data in dataSet if data[index] == value]


def chooseBestFeatureToSplit(dataSet):
    base_entropy = calcShannonEnt(dataSet)
    best_info_gain = 0.0
    best_feature = -1
    for i in range(len(dataSet[0]) - 1):  # the last column is the label, not a feature
        feature_count = Counter(data[i] for data in dataSet)
        # weighted entropy of the subsets produced by splitting on feature i
        new_entropy = sum(count / len(dataSet) * calcShannonEnt(splitDataSet(dataSet, i, value))
                          for value, count in feature_count.items())
        info_gain = base_entropy - new_entropy
        print('feature {} info gain: {:.3f}'.format(i, info_gain))
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature


def majorityCnt(classList):
    # most common class label, used when no features are left to split on
    return Counter(classList).most_common(1)[0][0]


def createTree(dataSet, labels):
    classList = [e[-1] for e in dataSet]
    # the recursive construction is described in the notes below; a completed sketch follows them


if __name__ == "__main__":
    dataset, labels = createDataset()
    print(calcShannonEnt(dataset))
```
Building the tree is then a recursive process. The recursion terminates when either:
- 1) all features available for splitting the dataset have been used; or
- 2) every instance under a branch has the same class.
The recursive flow (a completed sketch follows this list):
- find the best feature to split the dataset on
- split the dataset
- create a branch node
- for each split subset:
  - recurse, and add the returned subtree to the branch node
- return the branch node
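Following this recursion, here is a minimal sketch of how `createTree` could be completed, building on the `chooseBestFeatureToSplit`, `splitDataSet`, and `majorityCnt` helpers defined above; the tree is stored as nested dicts keyed by feature name and feature value:

```python
def createTree(dataSet, labels):
    classList = [e[-1] for e in dataSet]
    # stop condition 2): all instances under this branch share the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # stop condition 1): every splitting feature has been used, only the label column remains
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # choose the best feature and create a branch node keyed by its name
    best_feature = chooseBestFeatureToSplit(dataSet)
    best_label = labels[best_feature]
    tree = {best_label: {}}
    # splitDataSet removes the chosen feature column, so drop its name for the recursive calls
    sub_labels = labels[:best_feature] + labels[best_feature + 1:]
    # recurse on each subset produced by the values of the best feature
    for value in set(data[best_feature] for data in dataSet):
        tree[best_label][value] = createTree(splitDataSet(dataSet, best_feature, value), sub_labels)
    return tree
```

On the sample data above, `createTree(dataset, labels)` should produce `{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}`.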