"""
绘制树图形
"""
import pandas as pd
#from task8 import createTree
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
def plotNode(text, centerPt, parentPt, nodeType):
    """
    Draw one annotated node with an arrow pointing from parent to child.

    text: label drawn inside the node's text box
    centerPt: axes-fraction position of the node being drawn
    parentPt: axes-fraction position of the parent (arrow origin)
    nodeType: bbox style dict (decisionNode or leafNode)
    """
    arrow = dict(arrowstyle='<-')
    createPlot.ax1.annotate(
        text,
        xy=parentPt, xycoords='axes fraction',
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow,
    )
# Text-box style for internal (decision) nodes.
decisionNode = {'boxstyle': 'sawtooth', 'fc': '0.8'}
# Text-box style for leaf nodes.
leafNode = {'boxstyle': 'round4', 'fc': '0.8'}
def getNumLeafs(myTree):
    """
    Count the leaf nodes of a dict-encoded tree.

    Determines the horizontal (x) extent needed to plot the tree.
    A subtree is a dict {feature: {value: subtree_or_leaf}}; anything
    that is not a dict counts as one leaf.
    """
    root = next(iter(myTree))
    children = myTree[root]
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in children.values()
    )
def getTreeDepth(myTree):
    """
    Compute the depth of a dict-encoded tree.

    Determines the vertical (y) extent needed to plot the tree.
    A leaf child contributes depth 1; a dict child contributes
    1 + its own depth. An empty children dict yields 0.
    """
    root = next(iter(myTree))
    children = myTree[root]
    depths = [
        (1 + getTreeDepth(child)) if isinstance(child, dict) else 1
        for child in children.values()
    ]
    return max(depths) if depths else 0
def createPlot(tree):
    """
    Set up the figure/axes and render the whole decision tree.

    Stores the axes on createPlot.ax1 and seeds the layout state
    (total width/depth and running x/y offsets) that plotTree reads
    via function attributes, then draws the tree and shows the figure.
    """
    fig = plt.figure(1, facecolor='gray')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    # Shared layout state consumed/mutated by plotTree.
    plotTree.totalW = float(getNumLeafs(tree))
    plotTree.totalD = float(getTreeDepth(tree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(tree, (0.5, 1.0), '')
    plt.show()
def plotMidText(cntrPt, parentPt, txtString):
    """
    Write txtString at the midpoint of the edge between parent and child.
    """
    midX = (cntrPt[0] + parentPt[0]) / 2.0
    midY = (cntrPt[1] + parentPt[1]) / 2.0
    createPlot.ax1.text(midX, midY, txtString, va="center", ha="center", rotation=0)
def plotTree(myTree, parentPt, text):
    """
    Recursively draw the decision tree rooted at myTree.

    myTree: dict-encoded subtree {feature_label: {value: subtree_or_leaf}}
    parentPt: axes-fraction coordinates of the parent node
    text: label drawn on the edge from the parent to this node

    Relies on function attributes seeded by createPlot:
    plotTree.totalW/totalD (tree width/depth) and plotTree.xOff/yOff
    (running layout cursors, mutated during the traversal).
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Center this decision node horizontally above its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, text)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Descend one row before drawing children.
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            # Leaf: advance the x cursor one slot, draw node and edge label.
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the y cursor while unwinding the recursion.
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def calcShannonEnt(dataSet):
""" 计算信息熵
"""
# 获取最后一列的数据
labels = dataSet[:,-1]
# 统计所有类别对应出现的次数
labelCounts = Counter(labels)
# 数据已准备好,计算熵
shannonEnt = 0.0
dataLen = len(dataSet)
for key in labelCounts:
pro = labelCounts[key] / dataLen
shannonEnt -= pro * np.log2(pro)
return shannonEnt
def chooseFeature(dataSet):
    """
    Pick the feature (column index) with the highest information gain.

    gain = baseEntropy - weighted entropy of the subsets after splitting.
    dataSet: 2-D numpy array, features in all but the last column.
    Returns -1 when no split gives a positive gain.
    """
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the whole set
    total = float(len(dataSet))
    bestFeature = -1
    bestGain = 0.0
    # Examine every feature column (the last column is the label).
    for col in range(len(dataSet[0]) - 1):
        valueCounts = Counter(dataSet[:, col])
        splitEntropy = 0.0
        for value, count in valueCounts.items():
            subset = dataSet[dataSet[:, col] == value]
            splitEntropy += (count / total) * calcShannonEnt(subset)
        gain = baseEntropy - splitEntropy
        if gain > bestGain:
            bestGain, bestFeature = gain, col
    return bestFeature
def createTree(dataSet, feature_labels):
    """
    Build an ID3 decision tree and return it as a nested dict.

    dataSet: 2-D numpy array; the last column holds the class labels.
    feature_labels: names of the feature columns. NOT modified (the
        previous version destructively del'ed from the caller's list).

    Returns either a class label (leaf) or a nested dict of the form
    {feature_label: {feature_value: subtree_or_leaf}}.
    """
    labels = dataSet[:, -1]
    # All samples share one class: return that class as a leaf.
    if len(set(labels)) == 1:
        return labels[0]
    # No features left, or one remaining feature with a single value:
    # fall back to the majority class.
    if len(dataSet[0]) == 1 or (len(dataSet[0]) == 2 and len(set(dataSet[:, 0])) == 1):
        return Counter(labels).most_common(1)[0][0]
    # Choose the feature with the highest information gain.
    bestFeat = chooseFeature(dataSet)
    bestFeatLabel = feature_labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Build the reduced label list without mutating the argument,
    # so the caller's feature_labels survives intact.
    subLabels = [lbl for i, lbl in enumerate(feature_labels) if i != bestFeat]
    # Partition on each value of the chosen feature.
    for v in Counter(dataSet[:, bestFeat]):
        # The subset must not keep the column we just split on.
        subDataSet = np.delete(dataSet[dataSet[:, bestFeat] == v], bestFeat, axis=1)
        myTree[bestFeatLabel][v] = createTree(subDataSet, subLabels)
    return myTree
if __name__ == '__main__':
    # Load the dataset; the id column and continuous-valued columns are ignored.
    dataset = pd.read_csv("watermelon_3a.csv", usecols=['color', 'root', 'knocks', 'texture', 'navel', 'touch', 'label'])
    # Column names double as the feature labels (the last one is the class label).
    feature_labels = list(dataset.columns)
    dataset = dataset.values
    res = createTree(dataset, feature_labels)
    print(res)
    createPlot(res)
# task9
# (blog-scrape residue — originally published 2024-03-29 22:32:46; kept as a comment so the file parses)