决策树
目录
1.计算香农熵:
from math import log


def calcshannonEnt(dataSet):
    """Compute the Shannon entropy (base 2) of the class labels in a data set.

    Args:
        dataSet: sequence of records; the last element of each record is
            read as its class label.

    Returns:
        The entropy as a float; 0.0 when ``dataSet`` is empty.
    """
    total = len(dataSet)
    # Tally how many records carry each class label.
    tally = {}
    for record in dataSet:
        cls = record[-1]
        tally[cls] = tally.get(cls, 0) + 1
    # Accumulate -p * log2(p) over the observed label frequencies.
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy


def createdataSet():
    """Return the toy data set and the names of its two feature columns."""
    labels = ['no surfacing', 'flippers']
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return dataSet, labels


# Demo: entropy of the toy data set.
myDat, labels = createdataSet()
print(myDat)
print(labels)
print(calcshannonEnt(myDat))
print("------------------------------------------------------------")
# The more mixed the class labels are, the higher the entropy.
for rowIndex, newLabel in ((0, 'maybe'), (2, 'possible'), (3, 'certain')):
    myDat[rowIndex][-1] = newLabel
print(myDat)
print(calcshannonEnt(myDat))
2.按照给定特征划分数据集
""" 按照给定特征划分数据集 """ # 函数三个参数分别代表:待划分的数据集,划分数据集的特征,特征的返回值 def splitDataSet(dataSet, axis, value): reDataSet = [] for featVec in dataSet: if featVec[axis] == value: reduceFeatVec = featVec[: axis] reduceFeatVec.extend(featVec[axis+1:]) reDataSet.append(reduceFeatVec) return reDataSet myDat,labels = createdataSet() print(myDat) print("================================================") print(splitDataSet(myDat, 0, 1)) print(splitDataSet(myDat, 0, 0))
3.寻找最好的数据集划分方式
""" 遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的划分方式 """ def chooseBestFeatureToSplit(dataSet): # 求第一行有多少列的Feature numFeatures = len(dataSet[0]) - 1 # 计算没有经过划分的数据的香农熵 baseEntroy = calcshannonEnt(dataSet) # 最优的信息增益,最优的Feature编号 bestInfoGain = 0.0 bestFeature = -1 for i in range(numFeatures): ## 创建唯一的分类标签列表,获取第i个的所有特征(信息元素纵列) # 将dataSet 中的数据先按行依次放入example中,然后取得example[i]元素,放入列表featList featList = [example[i] for example in dataSet] # 使用set集,排除featList中重复的标签,得到唯一分类的集合 uniqueVals = set(featList) newEntropy = 0.0 # 遍历档次uniqueVals中所有的标签value for value in uniqueVals: # 对第i个数据划分数据集,返回所有包含i的数据(去掉第i个特征) subDataSet = splitDataSet(dataSet, i, value) # 计算包含i的数据占总数据的百分比 prob = len(subDataSet) / float(len(dataSet)) # 计算新的香农熵,不断进行迭代,这个计算过程仅在包含指定特征标签子集中进行 newEntropy += prob * calcshannonEnt(subDataSet) # 计算信息增益 infoGain = baseEntroy - newEntropy if (infoGain > bestInfoGain): # 更新信息增益 bestInfoGain = infoGain # 确定最优的增益的特征索引 bestFeature = i # 返回最优增益的索引 return bestFeature myDat, labels = createdataSet() print(chooseBestFeatureToSplit(myDat)) print(myDat)
4.多数表决处理非唯一类别分组
""" 遍历完所有的特征时,仍然不能将数据集划分成仅包含唯一类别的分组,采用多数表决法 """ import operator def majorityCnt(classList): classCount = {} for vote in classList: if vote not in classCount.keys(): classCount[vote] = 0 classCount[vote] += 1 sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0]
5.创建树
""" 创建树的函数代码 """ def createTree(dataSet, labels): # 返回当前数据集下标签列所有值 classList = [example[-1] for example in dataSet] # 当类别完全相同时则停止继续划分,直接返回该类的标签(决策树构造完成) if classList.count(classList[0]) == len(classList): return classList[0] # 遍历完所有特征时,仍然不能将数据集划分成仅包含唯一类别的分组,返回出现次数最多的类别标签作为返回值 if len(dataSet[0]) == 1: return majorityCnt(classList) # 获取最好分类特征索引 bestFeat = chooseBestFeatureToSplit(dataSet) # 获取该特征的名字 bestFeatLabel = labels[bestFeat] # 这里直接使用字典变量来存储树信息,用于绘制树形图 myTree = {bestFeatLabel: {}} # 删除已经在选取的特征 del(labels[bestFeat]) featValues = [example[bestFeat] for example in dataSet] uniqueVals = set(featValues) for value in uniqueVals: subLabels = labels[:] # 复制所有的标签 # 递归调用自身 myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels) return myTree myDat, labels = createdataSet() myTree = createTree(myDat, labels) print(myTree)
6.绘制注解树节点
""" 使用文本绘制注解树节点 """ import matplotlib.pyplot as plt # 决策点的属性,boxstyle是文本框类型,sawtooth是锯齿形,fc是边框粗细 decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 决策树叶子节点的属性 leafNode = dict(boxstyle="round4", fc="0.8") # 箭头的属性 arrow_args = dict(arrowstyle="<-") def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va='center', ha='center', bbox=nodeType, arrowprops=arrow_args) # nodeTxt为要显示的文本,centerPt为文本的中心点,parentPt为箭头指向文本的点,xy是箭头尖的坐标,xytext设置注释内容显示的中心位置 # xycoords和textcoords是坐标xy 与 xytext 的说明(按轴坐标), 若textcoords=None, 则默认textcoords与xycoords相同,若都未设置,则默认为data # va/ha设置节点框中文字的位置,va为纵向取值为(u'top',u'bottom',u'center',u'baseline'),ha为纵向取值为(u'center',u'right',u'left') def createPlot(): # 创建一个画布,背景为白色 fig = plt.figure(1, facecolor='white') fig.clf() # 画布清空 # ax1是函数createPlot的一个属性,这个可以在函数里面定义也可以在函数定义后加入也可以 createPlot.ax1 = plt.subplot(111, frameon=True) # frameon表示是否绘制坐标轴矩形 plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode) plotNode('a leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode) plt.show() createPlot()
7.获取叶节点的数目和树的层数
""" 获取叶节点的数目和树的层数 """ def getNumLeafs(myTree): # 初始化节点数 numLeafs = 0 firstside = list(myTree.keys()) # 找到输入的第一个元素,第一个关键词划分数据类别的标签 firstStr = firstside[0] secondDict = myTree[firstStr] # 测试数据是否为字典形式 for key in secondDict.keys(): # type判断子节点是否为字典类型 if type(secondDict[key]).__name__ == 'dict': # 若子节点也是字典,则也是判断节点,需要递归获取num numLeafs += getNumLeafs(secondDict[key]) else: numLeafs += 1 # 返回整棵树的节点数 return numLeafs def getTreeDepth(myTree): maxDepth = 0 firstside = list(myTree.keys()) firstStr = firstside[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth """ 输出预先存储的树信息,避免每次测试都需要重新创建树 """ def retrieveTree(i): listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}, {'no surfacing': {0: 'no', 1: {'flippers': {'head': {0: 'no', 1: 'yes'}}}}} ] return listOfTrees[i] myTree = retrieveTree(0) print(getTreeDepth(myTree)) print(getNumLeafs(myTree)) print(retrieveTree(0)) print(retrieveTree(1))
8.绘制
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway between a node and its parent."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below ``parentPt``.

    Args:
        myTree: nested-dict tree ``{featureName: {value: child, ...}}``.
        parentPt: (x, y) of the parent node in axes-fraction coordinates.
        nodeTxt: text for the edge from the parent to this subtree.

    Layout state lives in function attributes initialised by createPlot:
    ``plotTree.totalW`` / ``plotTree.totalD`` (leaf count and depth of the
    whole tree) and ``plotTree.xOff`` / ``plotTree.yOff`` (current cursor).
    """
    # Width of this subtree, measured in leaves.
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree)[0]  # feature tested at this node
    # Centre the decision node above the leaves it spans.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Children sit one level lower, so step the y cursor down.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            # Decision node: recurse into the subtree.
            plotTree(child, cntrPt, str(key))
        else:
            # Leaf: advance the x cursor by one leaf slot and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the y cursor for the caller's remaining siblings.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


# Entry point.
def createPlot(inTree):
    """Draw ``inTree`` on a fresh white figure and show it."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # hide the axis ticks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of the axes so that the first leaf,
    # after one full-slot step, lands at 0.5 / totalW.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


# Demo: add a third branch to the first sample tree and plot it.
myTree = retrieveTree(0)
# createPlot(myTree)
myTree['no surfacing'][3] = 'maybe'
print(myTree)
createPlot(myTree)
9.完整代码
from math import log


def calcshannonEnt(dataSet):
    """Compute the Shannon entropy (base 2) of the class labels in a data set.

    The last element of every record is read as its class label.
    Returns 0.0 for an empty data set.
    """
    numEntries = len(dataSet)
    # Count the records per class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    # Accumulate -p * log2(p) over the label frequencies.
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    """Return the toy data set and the names of its two feature columns."""
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


# Usage example (the more mixed the class labels, the higher the entropy):
# myDat, labels = createDataSet()
# myDat[0][-1] = 'maybe'
# myDat[2][-1] = 'possible'
# myDat[3][-1] = 'certain'
# print(myDat)
# print(labels)
# shannonEnt = calcshannonEnt(myDat)
# print(shannonEnt)


# Arguments: the data set to split, the index of the feature to split on,
# and the feature value that selects which records are kept.
def splitDataSet(dataSet, axis, value):
    """Select the records whose column ``axis`` equals ``value``.

    The matching column is removed from every returned record; the input
    records are not modified.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


# Usage example:
# myDat, labels = createDataSet()
# print(myDat)
# print(splitDataSet(myDat, 0, 1))
# print(splitDataSet(myDat, 0, 0))


def chooseBestFeatureToSplit(dataSet):
    """Pick the feature column whose split yields the largest information gain.

    Returns:
        The index of the best feature, or -1 when no feature gives a
        strictly positive gain.
    """
    # Number of feature columns; the last column is the class label.
    numFeatures = len(dataSet[0]) - 1
    # Entropy of the data before any split.
    baseEntropy = calcshannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Distinct values taken by feature i across all records.
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0
        for value in uniqueVals:
            # Records with this value, with feature column i removed.
            subDataSet = splitDataSet(dataSet, i, value)
            # Weight the subset's entropy by its share of the records.
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcshannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


# Usage example:
# myDat, labels = createDataSet()
# print(chooseBestFeatureToSplit(myDat))
# print(myDat)


import operator


def majorityCnt(classList):
    """Return the most frequent class label by majority vote.

    Used when all features are consumed but the records still carry more
    than one class. Ties go to the label that appeared first, because the
    sort is stable.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dictionaries.

    Args:
        dataSet: list of records; the last element of each record is its
            class label, the preceding elements are feature values.
        labels: feature names, index-aligned with the feature columns.
            The list is NOT modified.

    Returns:
        A class label when the branch is a leaf, otherwise a dict of the
        form ``{featureName: {featureValue: subtree_or_label, ...}}``.
    """
    # Class labels of all records in the current subset.
    classList = [example[-1] for example in dataSet]
    # All records share one class: the branch is finished.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # -1 means no feature gives a positive information gain. Indexing with
    # -1 would split on the class column itself, so vote instead.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    # The tree is stored as nested dicts, which the plotting code relies on.
    myTree = {bestFeatLabel: {}}
    # Build a new list without the chosen feature instead of deleting from
    # ``labels`` in place, so the caller's list stays intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


# Usage example:
# myDat, labels = createDataSet()
# myTree = createTree(myDat, labels)
print(myTree) """ 使用文本绘制注解树节点 """ import matplotlib.pyplot as plt # 决策点的属性,boxstyle是文本框类型,sawtooth是锯齿形,fc是边框粗细 decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 决策树叶子节点的属性 leafNode = dict(boxstyle="round4", fc="0.8") # 箭头的属性 arrow_args = dict(arrowstyle="<-") def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va='center', ha='center', bbox=nodeType, arrowprops=arrow_args) # nodeTxt为要显示的文本,centerPt为文本的中心点,parentPt为箭头指向文本的点,xy是箭头尖的坐标,xytext设置注释内容显示的中心位置 # xycoords和textcoords是坐标xy 与 xytext 的说明(按轴坐标), 若textcoords=None, 则默认textcoords与xycoords相同,若都未设置,则默认为data # va/ha设置节点框中文字的位置,va为纵向取值为(u'top',u'bottom',u'center',u'baseline'),ha为纵向取值为(u'center',u'right',u'left') # def createPlot(): # # 创建一个画布,背景为白色 # fig = plt.figure(1, facecolor='white') # fig.clf() # 画布清空 # # ax1是函数createPlot的一个属性,这个可以在函数里面定义也可以在函数定义后加入也可以 # createPlot.ax1 = plt.subplot(111, frameon=True) # frameon表示是否绘制坐标轴矩形 # plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode) # plotNode('a leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode) # plt.show() # createPlot() """ 获取叶节点的数目和树的层数 """ def getNumLeafs(myTree): # 初始化节点数 numLeafs = 0 # python3替换注释的两行代码 firstside = list(myTree.keys()) firstStr = firstside[0] # 找到输入的第一个元素,第一个关键词划分数据类别的标签 secondDict = myTree[firstStr] # firstStr = myTree.keys()[] # secondDict = myTree[firstStr] for key in secondDict.keys(): # 测试数据是否为字典形式 # type判断子节点是否为字典类型 if type(secondDict[key]).__name__ == 'dict': numLeafs += getNumLeafs(secondDict[key]) # 若子节点也为字典,则也是判断节点,需要递归获取num else: numLeafs += 1 # 返回整棵树的节点数 return numLeafs def getTreeDepth(myTree): maxDepth = 0 firstside = list(myTree.keys()) firstStr = firstside[0] secondDict = myTree[firstStr] # firstStr = myTree.keys()[0] # secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > 
maxDepth: maxDepth = thisDepth return maxDepth # 输出预先存储的树信息,避免每次测试都需要重新创建树 def retrieveTree(i): listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}, {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}} ] return listOfTrees[i] # # print(retrieveTree(1)) # myTree = retrieveTree(0) # print(getNumLeafs(myTree)) # print(getTreeDepth(myTree)) # def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString) def plotTree(myTree, parentPt, nodeTxt): # 计算树的宽度 totalW numLeafs = getNumLeafs(myTree) # 计算树的高度 totalD depth = getTreeDepth(myTree) firstside = list(myTree.keys()) firstStr = firstside[0] # 找到输入的第一个元素 # 按照叶子节点个数划分x轴 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) # 标注节点属性 plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] # y方向上的摆放位置自下而上绘制,因此递减y值 plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): # 判断是否为字典 不是则为叶子节点 if type(secondDict[key]).__name__ == 'dict': # 递归继续向下找 plotTree(secondDict[key], cntrPt, str(key)) else: # 为叶子节点 # x方向计算节点坐标 plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW # 绘制 plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) # 添加文本信息 plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) # 下次重新调用时回恢复y plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD # 主函数 def createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.xOff = -0.5/plotTree.totalW plotTree.yOff = 1.0 plotTree(inTree, (0.5, 1.0), '') plt.show() myTree = retrieveTree(0) # createPlot(myTree) myTree['no surfacing'][3] = 'maybe' 
print(myTree) createPlot(myTree)