from math import log
import csv
import matplotlib.pyplot as plt
import operator
import numpy as np
import pandas as pd
# Box styles handed to pltNode as its ``type`` argument (forwarded as the
# annotation's ``bbox``), and the arrow style used for every annotation.
decisionNode = {'boxstyle': 'sawtooth', 'fc': '0.8'}
leafNode = {'boxstyle': 'round4', 'fc': '0.8'}
arrow_args = {'arrowstyle': '<-'}
def pltNode(txt, centerPt, parentPt, type):
    """Draw one annotated node on the axes stored in ``createPlt.ax1``.

    Args:
        txt: Text shown inside the node box.
        centerPt: (x, y) position of the text box, in axes-fraction units.
        parentPt: (x, y) position of the annotated point, in axes-fraction
            units; the arrow joins this point and the text box.
        type: Box-style dict forwarded as the annotation's ``bbox``
            (the name shadows the builtin but is kept for callers).
    """
    axes = createPlt.ax1
    axes.annotate(
        txt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va='center',
        ha='center',
        bbox=type,
        arrowprops=arrow_args,
    )
def createPlt():
    """Show a demo figure containing one decision node and one leaf node.

    Clears figure 1, stores a frameless subplot on ``createPlt.ax1`` (the
    axes that pltNode draws on), draws the two sample nodes and displays
    the figure.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    createPlt.ax1 = plt.subplot(111, frameon=False)

    # (text, box position, annotated point, box style) for each demo node.
    demoNodes = (
        ('a decision node ', (0.9, 0.1), (0.1, 0.5), decisionNode),
        ('a leaf node ', (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for text, center, parent, boxStyle in demoNodes:
        pltNode(text, center, parent, boxStyle)

    plt.show()
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    Args:
        myTree: Tree of the form ``{feature: {value: subtree_or_leaf, ...}}``.
            Only the first key of ``myTree`` is inspected. Any child that
            is a dict (including dict subclasses) is treated as a subtree;
            everything else counts as one leaf.

    Returns:
        The number of leaves below the root.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    branches = list(myTree.values())[0]
    count = 0
    for child in branches.values():
        # isinstance (rather than comparing the type name) so that subtrees
        # stored in dict subclasses such as OrderedDict are still descended.
        if isinstance(child, dict):
            count += getNumLeafs(child)
        else:
            count += 1
    return count
def getDepth(myTree):
    """Return the depth of a nested-dict decision tree.

    Args:
        myTree: Tree of the form ``{feature: {value: subtree_or_leaf, ...}}``.
            Only the first key of ``myTree`` is inspected. Any child that
            is a dict (including dict subclasses) is treated as a subtree.

    Returns:
        The number of decision levels on the longest root-to-leaf path:
        1 for a tree whose children are all leaves, 0 if the root has no
        children.

    Raises:
        IndexError: If ``myTree`` is empty.
    """
    branches = list(myTree.values())[0]
    # A leaf child contributes one level; a subtree contributes one level
    # plus its own depth. No output is printed while walking the tree.
    return max(
        (1 + getDepth(child) if isinstance(child, dict) else 1
         for child in branches.values()),
        default=0,
    )
# Compute the Shannon entropy
def calShannonEnt(dataSet):
    """Compute the Shannon entropy (base 2) of the class labels in dataSet.

    Args:
        dataSet: Sequence of rows; the last element of each row is taken
            as its class label.

    Returns:
        The entropy as a float; 0.0 for an empty data set.
    """
    total = len(dataSet)

    # Tally how many rows carry each class label.
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1

    entropy = 0.0
    for occurrences in tally.values():
        p = occurrences / total
        entropy -= p * log(p, 2)
    return entropy
# Split the data set: axis is the column index, value is the feature value to match in that column
def splitDataSet(dateSet, axis, value):
    """Select the rows whose column ``axis`` equals ``value``, dropping that column.

    Args:
        dateSet: Sequence of rows (lists or tuples).
        axis: Index of the column to test; negative indices count from the
            end of each row.
        value: Feature value a row must have in column ``axis`` to be kept.

    Returns:
        A new list of new lists: each matching row without its ``axis``
        column. The input rows are not modified.

    Raises:
        IndexError: If ``axis`` is out of range for a row.
    """
    retDateSet = []
    for featVec in dateSet:
        if featVec[axis] == value:
            # Copy, then delete the column. Unlike slicing around
            # ``axis + 1`` this is also correct for negative indices,
            # and the copy accepts tuple rows.
            reduced = list(featVec)
            del reduced[axis]
            retDateSet.append(reduced)
    return retDateSet
# Create the sample data set
def createDataSet():
    """Build the five-row sample data set and its feature names.

    Returns:
        A ``(dataSet, labels)`` tuple. Each row of ``dataSet`` is
        ``[feature0, feature1, class_label]``; ``labels`` names the two
        feature columns. Fresh lists are created on every call.
    """
    labels = ['no surfacing', 'flippers']
    rows = (
        (1, 1, 'yes'),
        (1, 1, 'yes'),
        (1, 0, 'no'),
        (0, 1, 'no'),
        (0, 1, 'no'),
    )
    return [list(row) for row in rows], labels
# Choose the best feature to split on
def bestSplit(dataSet):
    """Find the feature column whose split gives the largest information gain.

    Args:
        dataSet: Non-empty sequence of rows; every column except the last
            is a feature, the last is the class label.

    Returns:
        The index of the best feature column, or -1 when no column yields
        a gain strictly greater than 0.0.
    """
    rowCount = len(dataSet)
    featureCount = len(dataSet[0]) - 1  # the last column is the class label
    baseEntropy = calShannonEnt(dataSet)

    winner = -1
    winnerGain = 0.0
    for col in range(featureCount):
        # Entropy left after partitioning on each distinct value of this
        # column, weighted by the share of rows in every partition.
        distinct = set([row[col] for row in dataSet])
        splitEntropy = 0.0
        for val in distinct:
            part = splitDataSet(dataSet, col, val)
            weight = len(part) / rowCount
            splitEntropy += weight * calShannonEnt(part)

        gain = baseEntropy - splitEntropy
        if gain > winnerGain:
            winner, winnerGain = col, gain
    return winner
def vote(classList):
    """Return the most frequent class label in classList (majority vote).

    Args:
        classList: Non-empty sequence of hashable class labels.

    Returns:
        The label that occurs most often. When several labels tie, the one
        that appears first in ``classList`` wins (the sort is stable).

    Raises:
        IndexError: If ``classList`` is empty.
    """
    classCount = {}
    for label in classList:
        classCount[label] = classCount.get(label, 0) + 1
    ranked = sorted(classCount.items(), key=lambda item: item[1], reverse=True)
    # Each item is (label, count): index 0 is the label, index 1 the count.
    return ranked[0][0]
def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: Non-empty sequence of rows; every column except the last
            is a feature, the last is the class label.
        labels: Feature names, one per feature column of ``dataSet``.
            The list is not modified.

    Returns:
        A class label when the rows cannot or need not be split further,
        otherwise ``{feature_name: {feature_value: subtree_or_label, ...}}``.

    Raises:
        IndexError: If ``dataSet`` is empty.
    """
    classList = [row[-1] for row in dataSet]
    # Every row has the same class: this branch is a pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the label column is left: no features remain to split on.
    if len(dataSet[0]) == 1:
        return vote(classList)

    bestFeat = bestSplit(dataSet)
    if bestFeat < 0:
        # bestSplit returns -1 when no feature gives a positive gain.
        # Indexing with -1 would pick the last label and split on the
        # class column, so stop here and fall back to a vote instead.
        return vote(classList)

    bestLabel = labels[bestFeat]
    # Build a reduced copy instead of deleting from the caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    branches = {}
    for v in set(row[bestFeat] for row in dataSet):
        branches[v] = createTree(splitDataSet(dataSet, bestFeat, v), subLabels)
    return {bestLabel: branches}
if __name__ == '__main__':
    # Build the decision tree for the sample data set and print it.
    myTree = createTree(*createDataSet())
    print(myTree)