构建决策树:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/4/5 20:21
# @Author : HJH
# @Site :
# @File : decision_tree_scatter.py
# @Software: PyCharm
import operator
import os
import pickle
from collections import Counter
from math import log

import numpy as np
from sklearn.datasets import load_iris

import treePlotter
def loadDataSet():
    """Load the contact-lenses dataset from ./lenses.txt.

    Returns:
        (records, featureNames): each record is a list of tab-separated
        string values whose last element is the class label; featureNames
        names the four feature columns, in order.
    """
    with open('./lenses.txt') as fh:
        records = [line.strip().split('\t') for line in fh]
    featureNames = ['age', 'prescript', 'astigmatic', 'tearRate']
    return records, featureNames
#计算数据集的熵
#计算数据集的熵
def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of the class labels in dataSet.

    dataSet: list of records; the last element of each record is the
    class label. Returns the entropy in bits (log base 2).
    """
    m = len(dataSet)
    # Count occurrences of each class label (last column) in one pass.
    labelCounts = Counter(featVec[-1] for featVec in dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / m
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
#划分数据集(参数:带划分数据集,需要划分数据集中的哪一列特征,需要返回哪一个特征值)
def splitDataSet(dataSet, axis, value):#splitDataSet(dataset, 1, 1)
#为了不修改原数据集,创建新列表
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
# print(reducedFeatVec)
#>>[1]
reducedFeatVec.extend(featVec[axis+1:])
# print(reducedFeatVec)
# >>[1, 'yes']
retDataSet.append(reducedFeatVec)
# print(retDataSet)
# >>[[1, 'yes']]
return retDataSet
#选择最好的特征集划分
#选择最好的特征集划分
def chooseBestFeatureToSplit(dataSet):
    """Pick the index of the feature whose split gives the largest
    information gain; returns -1 when no split improves the base entropy.
    The last column of each record is treated as the class label."""
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    total = float(len(dataSet))
    bestFeature = -1
    bestInfoGain = 0.0
    for featIdx in range(numFeatures):
        # Weighted entropy of the partitions induced by this feature.
        splitEntropy = 0.0
        for val in {record[featIdx] for record in dataSet}:
            subset = splitDataSet(dataSet, featIdx, val)
            splitEntropy += (len(subset) / total) * calcShannonEnt(subset)
        gain = baseEntropy - splitEntropy
        if gain > bestInfoGain:
            bestInfoGain = gain
            bestFeature = featIdx
    return bestFeature
#出现次数最多的类别
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
#创建决策树
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
#如果类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
#使用完所有的特征仍不能将数据集划分为仅包含唯一类别的分组
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
#删除标签
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
# print(subLabels)
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree
#决策树的分类函数
def classify(inputTree, featLabels, testVec):
#第一个分类特征的键
firstStr = list(inputTree.keys())[0]
# 第一个分类特征的值,即第二个字典
secondDict = inputTree[firstStr]
# print(secondDict)
#将标签转换为索引,index方法查找当前列表中第一个匹配firstStr的索引
featIndex = featLabels.index(firstStr)
# print(featIndex)
#根据索引获取测试集中对应特征的值
key = testVec[featIndex]
valueOfFeat = secondDict[key]
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else:
classLabel = valueOfFeat
return classLabel
#用pickle序列化存储决策树
def storeTree(inputTree, filename):
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
    """Load and return a pickled decision tree from `filename`."""
    with open(filename, 'rb') as handle:
        return pickle.load(handle)
if __name__ == '__main__':
    # Reuse a previously cached tree when one exists; otherwise build
    # it from the lenses data and persist it for the next run.
    if os.path.exists('./strotree.txt'):
        myTree = grabTree('./strotree.txt')
    else:
        dataset, labels = loadDataSet()
        myTree = createTree(dataset, labels)
        storeTree(myTree, './strotree.txt')
    # Reload a fresh label list for classification (createTree may have
    # consumed the previous one), then classify a sample and plot.
    dataset, labels = loadDataSet()
    print(classify(myTree, labels, ["young", "myope", "no", "normal"]))
    treePlotter.createPlot(myTree)
可视化决策树:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/4/6 20:28
# @Author : HJH
# @Site :
# @File : treePlotter.py
# @Software: PyCharm
import matplotlib.pyplot as plt
# Box and arrow styles shared by all the plotting helpers below.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # internal (decision) nodes
leafNode = dict(boxstyle="round4", fc="0.8")  # leaf (class label) nodes
arrow_args = dict(arrowstyle="<-")  # arrow drawn from child back to parent
# 获取叶子节点数目
# 获取叶子节点数目
def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    myTree: {featureName: {value: subtree-or-label}} as produced by
    createTree; any non-dict branch value is a leaf.
    """
    numLeafs = 0
    rootLabel = next(iter(myTree))
    for subtree in myTree[rootLabel].values():
        # isinstance replaces the fragile type(...).__name__ == 'dict' check.
        if isinstance(subtree, dict):
            numLeafs += getNumLeafs(subtree)
        else:
            numLeafs += 1
    return numLeafs
# 获取树的层数
# 获取树的层数
def getTreeDepth(myTree):
    """Return the depth (number of decision levels) of the tree.

    A tree whose branches are all leaves has depth 1.
    """
    maxDepth = 0
    rootLabel = next(iter(myTree))
    for subtree in myTree[rootLabel].values():
        # isinstance replaces the fragile type(...).__name__ == 'dict' check.
        thisDepth = 1 + getTreeDepth(subtree) if isinstance(subtree, dict) else 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
# Draw one annotated node box with an arrow from its parent.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw `nodeTxt` in a box styled by `nodeType` at centerPt, with an
    arrow pointing from parentPt; coordinates are axes fractions.
    Relies on createPlot() having stored the active axes on createPlot.ax1."""
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# Annotate the edge between a parent and child node.
def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the parent-child edge."""
    midX = (parentPt[0] + cntrPt[0]) / 2.0
    midY = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(midX, midY, txtString, va="center", ha="center", rotation=30)
# Main recursive routine: lay out and draw the subtree below parentPt.
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively plot the subtree `myTree`, labelling its incoming
    edge with `nodeTxt`.

    Layout state lives in function attributes initialized by createPlot():
    plotTree.totalW / plotTree.totalD are the whole tree's leaf count and
    depth; plotTree.xOff / plotTree.yOff are the pen position, advanced
    as leaves and levels are drawn.
    """
    # The leaf count decides how much horizontal room this subtree claims.
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)  # NOTE(review): computed but never used here
    firstStr = list(myTree.keys())[0]
    # Center this decision node above the span of its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # Label the incoming edge, then draw the decision node itself.
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Descend one level: children are drawn one row lower.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        # A dict branch is another decision node; anything else is a leaf.
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))  # recursion
        else:
            # Leaf: advance the x cursor, draw the class label and edge text.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Pop back up one row before returning to the caller.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# if you do get a dictonary you know it's a tree, and the first element will be another dict
def createPlot(inTree):
    """Render the decision tree `inTree` in a new matplotlib figure.

    Seeds the layout state consumed by plotTree (total width/depth and
    the initial pen position), then draws the tree and shows the window.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # Axes handle is shared with plotNode/plotMidText via a function attribute.
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # Global layout state read and updated by plotTree.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW;
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
# def createPlot():
# fig = plt.figure(1, facecolor='white')
# fig.clf()
# createPlot.ax1 = plt.subplot(111, frameon=False) # ticks for demo puropses
# plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
# plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
# plt.show()
def retrieveTree(i):
    """Return one of two canned decision trees (i = 0 or 1) for testing
    the plotting routines without rebuilding a tree from data."""
    cannedTrees = (
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    )
    return cannedTrees[i]
lenses.txt:
young myope no reduced no lenses
young myope no normal soft
young myope yes reduced no lenses
young myope yes normal hard
young hyper no reduced no lenses
young hyper no normal soft
young hyper yes reduced no lenses
young hyper yes normal hard
pre myope no reduced no lenses
pre myope no normal soft
pre myope yes reduced no lenses
pre myope yes normal hard
pre hyper no reduced no lenses
pre hyper no normal soft
pre hyper yes reduced no lenses
pre hyper yes normal no lenses
presbyopic myope no reduced no lenses
presbyopic myope no normal no lenses
presbyopic myope yes reduced no lenses
presbyopic myope yes normal hard
presbyopic hyper no reduced no lenses
presbyopic hyper no normal soft
presbyopic hyper yes reduced no lenses
presbyopic hyper yes normal no lenses