2. Building the decision tree
2.1 Computing the Shannon entropy of a dataset
from math import log

def calcShannonEnt(dataSet):  # compute the Shannon entropy of the dataset
    numEntries = len(dataSet)  # total number of instances in the dataset
    labelCounts = {}  # dictionary counting how often each class label occurs
    for featVec in dataSet:
        currentLabel = featVec[-1]  # the last element of each instance is its class label
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries  # probability of each class
        shannonEnt -= prob * log(prob, 2)  # accumulate the entropy
    return shannonEnt
2.2 Creating the dataset
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
Interactive test:
import decision_tree
myDat, labels = decision_tree.createDataSet()
myDat
Out[10]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
decision_tree.calcShannonEnt(myDat)
Out[13]: 0.9709505944546686
The more mixed a dataset is, the higher its entropy. Test:
myDat
Out[33]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
myDat[0][-1] = 'maybe'
decision_tree.calcShannonEnt(myDat)
Out[35]: 1.3709505944546687
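As a sanity check (my addition, not from the original post): the unmodified dataset has two 'yes' and three 'no' labels, so evaluating H = -Σ p·log2(p) directly reproduces Out[13]:

from math import log
print(-(2/5)*log(2/5, 2) - (3/5)*log(3/5, 2))  # 0.9709505944546686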
2.3 Splitting the dataset on a given feature
def splitDataSet(dataSet, axis, value):  # split the dataset on a given feature (dataset to split, index of the splitting feature, feature value to keep)
    retDataSet = []  # new list object, so the input dataset is not modified
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # take the features before axis
            reducedFeatVec.extend(featVec[axis+1:])  # splice in the features after axis
            retDataSet.append(reducedFeatVec)
    return retDataSet  # the subset of instances matching value, with the splitting feature removed
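The extend/append pair matters here: extend splices the elements of featVec[axis+1:] into reducedFeatVec, whereas append would nest the slice as a single element. A minimal illustration (my addition):

a = [1, 2]
a.extend([3, 4])  # a == [1, 2, 3, 4]
b = [1, 2]
b.append([3, 4])  # b == [1, 2, [3, 4]]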
Interactive test:
importlib.reload(decision_tree)
Out[36]: <module 'decision_tree' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\decision_tree.py'>
myDat, labels = decision_tree.createDataSet()
myDat
Out[38]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
decision_tree.splitDataSet(myDat, 0, 1)
Out[39]: [[1, 'yes'], [1, 'yes'], [0, 'no']]
decision_tree.splitDataSet(myDat, 0, 0)
Out[40]: [[1, 'no'], [1, 'no']]
2.4 Choosing the best split of the dataset
def chooseBestFeatureToSplit(dataSet):  # choose the best feature to split the dataset on
    numFeatures = len(dataSet[0]) - 1  # number of features (the last column is the class label)
    baseEntropy = calcShannonEnt(dataSet)  # Shannon entropy of the whole dataset
    bestInfoGain = 0.0  # initialize the best information gain
    bestFeature = -1
    for i in range(numFeatures):  # loop over the features
        featList = [example[i] for example in dataSet]  # all values of the i-th feature
        uniqueVals = set(featList)  # deduplicate
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)  # subset for this feature value
            prob = len(subDataSet)/float(len(dataSet))  # probability of this value
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted entropy after the split
        infoGain = baseEntropy - newEntropy  # information gain
        if infoGain > bestInfoGain:  # keep the feature with the largest gain so far
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
Interactive test:
importlib.reload(decision_tree)
Out[48]: <module 'decision_tree' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\decision_tree.py'>
myDat, labels = decision_tree.createDataSet()
decision_tree.chooseBestFeatureToSplit(myDat)
Out[50]: 0
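The result 0 means the first feature, 'no surfacing', gives the largest information gain. A quick hand check (my sketch, not from the original post):

from math import log

def H(ps):  # entropy from a list of class probabilities
    return -sum(p * log(p, 2) for p in ps if p)

base = H([2/5, 3/5])                # ≈ 0.971, entropy of the full dataset
gain0 = base - (3/5)*H([2/3, 1/3])  # split on 'no surfacing'; the value-0 subset is pure
gain1 = base - (4/5)*H([2/4, 2/4])  # split on 'flippers'; the value-0 subset is pure
print(gain0, gain1)                 # ≈ 0.420 vs ≈ 0.171, so feature 0 wins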
2.5 The majority class
import operator

def majorityCnt(classList):  # return the class that occurs most often; argument: list of class names
    classCount = {}  # empty dictionary
    for vote in classList:  # count how often each class occurs
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedclassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # sort by frequency, descending
    return sortedclassCount[0][0]  # the name of the most frequent class
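A minimal usage check (my addition, not in the original post):

print(majorityCnt(['yes', 'no', 'no', 'maybe', 'no']))  # 'no'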
2.6 Creating the tree
def createTree(dataSet, labels):  # create the tree; arguments: the dataset and the list of all feature labels
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):  # stop splitting when all classes are identical
        return classList[0]
    if len(dataSet[0]) == 1:  # all features used up: return the majority class
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # pick the best feature
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}  # the dictionary myTree stores the whole tree
    del(labels[bestFeat])  # remove the label of the chosen feature
    featValues = [example[bestFeat] for example in dataSet]  # all values of the chosen feature
    uniqueVals = set(featValues)  # deduplicate
    for value in uniqueVals:
        subLabels = labels[:]  # copy the remaining labels so recursive calls don't interfere
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
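Note that del(labels[bestFeat]) mutates the caller's labels list, which is why the interactive sessions below recreate it with createDataSet() before calling classify. A defensive-copy sketch (my addition):

myDat, labels = createDataSet()
myTree = createTree(myDat, labels[:])  # pass a copy so labels stays intact for classify()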
Interactive test:
importlib.reload(decision_tree)
myDat, labels = decision_tree.createDataSet()
myTree = decision_tree.createTree(myDat, labels)
myTree
Out[56]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
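Read the nested dictionary from the outside in: the root key 'no surfacing' is the first split; feature value 0 leads directly to the leaf 'no', while value 1 leads to a second split on 'flippers'.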
3. Drawing the tree with Matplotlib annotations
3.1 A first plot: setting up the basic structure
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # text-box styles for decision and leaf nodes
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")  # arrow style
plt.rcParams['font.sans-serif'] = ['SimHei']  # render the Chinese node labels correctly

def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # the actual drawing routine
    createplot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)  # node text, arrow tip, text position; the coords arguments pick the coordinate system

def createplot():  # draw the overall figure
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createplot.ax1 = plt.subplot(111, frameon=False)
    plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)  # "decision node"
    plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)  # "leaf node"
    plt.show()
Initial test
Matplotlib guide: https://blog.csdn.net/wizardforcel/article/details/54782628
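A minimal invocation (my sketch; the resulting figure is not included in this text):

import treePlotter
treePlotter.createplot()  # draws a sample decision node and leaf node joined by arrows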
3.2 Getting the number of leaf nodes and the tree depth, with tests
def getNumLeafs(myTree):  # count the leaf nodes
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # test whether the node's value is a dictionary
            numLeafs += getNumLeafs(secondDict[key])  # if so, keep counting in the subtree
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # test whether the node's value is a dictionary
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):  # return pre-stored trees, for testing
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]
Interactive test:
importlib.reload(treePlotter)
Out[55]: <module 'treePlotter' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\treePlotter.py'>
myTree = treePlotter.retrieveTree(1)
myTree
Out[57]:
{'no surfacing': {0: 'no',
1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
myTree = treePlotter.retrieveTree(0)
treePlotter.getNumLeafs(myTree)
Out[59]: 3
treePlotter.getTreeDepth(myTree)
Out[60]: 2
3.3 Combining the methods to draw the complete tree
def plotMidText(cntrPt, parentPt, txtString):  # fill in text between parent and child nodes
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  # x coordinate of the label
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  # y coordinate of the label
    createplot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # number of leaf nodes in this subtree
    depth = getTreeDepth(myTree)  # depth of this subtree
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)  # position of the node text
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the edge between parent and child
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # draw the node
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  # move the y offset down one level
    for key in secondDict.keys():  # walk the whole tree
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  # restore the y offset on the way back up

def createplot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()  # clear the previous figure
    axprops = dict(xticks=[], yticks=[])
    createplot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))  # tree width
    plotTree.totalD = float(getTreeDepth(inTree))  # tree depth
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
Interactive test:
importlib.reload(treePlotter)
Out[19]: <module 'treePlotter' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\treePlotter.py'>
myTree = treePlotter.retrieveTree(0)
treePlotter.createplot(myTree)
4. Testing the algorithm
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # translate the feature label into an index
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':  # internal node: recurse into the subtree
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:  # leaf node: this is the predicted class
                classLabel = secondDict[key]
    return classLabel
Interactive test:
importlib.reload(decision_tree)
Out[16]: <module 'decision_tree' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\decision_tree.py'>
myDat, labels = decision_tree.createDataSet()
myTree = decision_tree.retrieveTree(0)
decision_tree.classify(myTree, labels, [1, 0])
Out[19]: 'no'
decision_tree.classify(myTree, labels, [1, 1])
Out[20]: 'yes'
4.1 Storing the decision tree
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)
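The session below only shows the load; the tree would have been stored first with a call like this (my sketch, using the filename from the session):

decision_tree.storeTree(myTree, 'classifierStorage.txt')  # pickle the tree to disk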
importlib.reload(decision_tree)
Out[30]: <module 'decision_tree' from 'C:\\Users\\xuning\\PycharmProjects\\machine learning\\decision_tree\\decision_tree.py'>
decision_tree.grabTree('classifierStorage.txt')
Out[31]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
4.2 Using the decision tree to predict contact-lens type
fr = open('lenses.txt')
fr
Out[36]: <_io.TextIOWrapper name='lenses.txt' mode='r' encoding='cp936'>
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = decision_tree.createTree(lenses, lensesLabels)
lensesTree
import treePlotter
Backend Qt5Agg is interactive backend. Turning interactive mode on.
treePlotter.createplot(lensesTree)