决策树
一种依托于策略抉择而建立起来的树。
从数据产生决策树的机器学习技术叫做决策树学习。
数据形式:决策过程只有:是/否
适用数据类型:数值型和标称型
标称型:其实就是离散型数据,变量的结果只在有限目标集中取值。
信息增益
信息熵:
表示信息的混乱程度,也就是说:信息越有序,信息熵越低。
信息增益:
信息增益越大,做的东西越好——为了找划分数据集的最好特征
划分数据集的最大原则是:将无序的数据变得更加有序。
from math import log
import operator
def createDataSet():
#每一行代表不同的数据,一共有五个数据
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
#第一个1代表是否露出水面,第二个1代表是否露出脚蹼,第三个是结果yes/no(是否是鱼类)
labels = ['no surfacing', 'flippers']
#change to discrete values
return dataSet, labels#第一个是数据集,第二是描述(标签)
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet: #the the number of unique elements and their occurance
currentLabel = featVec[-1]#最后一列遍历,统计
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob, 2) #log base 2#香农熵求出香农值
return shannonEnt
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
#axis列为value的数据集【该数据集需要排除axis列】
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] #chop out axis used for splitting
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def chooseBestFeatureToSplit(dataSet):
#求第一行有多少列的Feature,(减去最后一列是label列)
numFeatures = len(dataSet[0]) - 1 #the last column is used for the labels
#label的信息熵(代表整体数据集的信息(混乱程度))
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
for i in range(numFeatures): #iterate over all the features
#获取每一个feature的list集合
featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
#获取剔重后的集合
uniqueVals = set(featList) #get a set of unique values
#创建一个临时的信息熵
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
#value代表label(a,b,c,d),i代表列
#计算概率
prob = len(subDataSet)/float(len(dataSet))
#计算信息熵
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy #calculate the info gain; ie reduction in entropy
#gain【信息增益】:划分数据集前后的信息变化,获取信息熵最大的值
#划分越有序就作为那个根
if (infoGain > bestInfoGain): #compare this to the best gain so far
bestInfoGain = infoGain #if better than current best, set to best
bestFeature = i
return bestFeature #returns an integer
def majorityCnt(classList):
classCount={}#字典声明
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
#reverse=True倒序
return sortedClassCount[0][0]
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet] #取分类标签
if classList.count(classList[0]) == len(classList): #如果类别完全相同则停止继续划分
return classList[0]
if len(dataSet[0]) == 1: #遍历完所有特征时返回出现次数最多的类标签
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) #选择最优特征
bestFeatLabel = labels[bestFeat] #最优特征的标签
myTree = {bestFeatLabel:{}} #根据最优特征的标签生成树
del(labels[bestFeat]) #删除已经使用特征标签
featValues =[example[bestFeat] for example in dataSet] #得到训练集中所有最优特征的属性值
uniqueVals = set(featValues) #去掉重复的属性值
for value in uniqueVals: #遍历特征,创建决策树。
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree
def classify(inputTree, featLabels, testVec):
'''
inputTree:决策树模型
featLabels:Feature标签对应的名称
testVec:测试输入的数据
返回值:classlabel分类的结果,需要映射到label才能知道名称
'''
#获取tree的根节点对应的key值
firstStr = list(inputTree)[0]
#通过key得到根节点对应的value
secondDict = inputTree[firstStr]
#传入featLabels的名称,求出对应根的名称。index名称对应的坐标
featIndex = featLabels.index(firstStr)
key = testVec[featIndex]
valueOfFeat = secondDict[key]
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else: classLabel = valueOfFeat
return classLabel
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb')
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename, 'rb')
return pickle.load(fr)
总结:
1.extend与append的区别
>>> A = ['q', 'w', 'e', 'r']
>>> A.extend(['t', 'y'])
>>> A
['q', 'w', 'e', 'r', 't', 'y']
>>>len(A)
6
>>> B = ['q', 'w', 'e', 'r']
>>> B.append(['t', 'y'])
>>> B
['q', 'w', 'e', 'r', ['t', 'y']]
>>>len(B)
5
使用文本注解绘制树节点
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc='0.8') #设置节点格式
leafNode = dict(boxstyle="round4", fc='0.8') #设置叶节点格式
arrow_args = dict(arrowstyle="<-") #定义箭头格式
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt,xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',va='center',
ha='center',bbox=nodeType,arrowprops=arrow_args) #绘制节点
def plotMidText(cntrPt, parentPt, txtString): #计算标注位置
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid,yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) #获取决策树叶结点数目,决定了树的宽度
depth = getTreeDepth(myTree) #获取决策树层数
firstStr = next(iter(myTree))
cntrPt = (plotTree.xOff +(1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff) #中心位置
plotMidText(cntrPt, parentPt, nodeTxt) #标注有向边属性值
plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制结点
secondDict = myTree[firstStr] #下一个字典,也就是继续绘制子结点
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #y偏移
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict': #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
plotTree(secondDict[key],cntrPt,str(key)) #不是叶结点,递归调用继续绘制
else: #如果是叶结点,绘制叶结点,并标注有向边属性值
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt, leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1,facecolor='white') #创建fig
fig.clf() #清空fig
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #去掉x、y轴
plotTree.totalW = float(getNumLeafs(inTree)) #获取决策树叶结点数目
plotTree.totalD = float(getTreeDepth(inTree)) #获取决策树层数
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0 #x偏移
plotTree(inTree,(0.5,1.0),'') #绘制决策树
plt.show() #显示绘制结果
使用pickle模块存储决策树
def storeTree(inputTree, filename):
import pickle
fw = open(filename,'w')
pickle.dump(inputTree,fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
fr=open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age','prescript','astigmatic','tearRate']
lensesTree = createTree(lenses,lensesLabels)
print(lensesTree)
createPlot(lensesTree)