决策树
什么是决策树
简而言之,就是通过一步步的决策,来对某些事物分类(说着说最终做出什么选择)。其特点是树状结构,能够更快的分类,如下图:
如何分类
大家应该都学过哈夫曼树,他的思想是用的最多的放在最靠近根节点的位置。决策树也类似,但是其分类依据是根据信息熵来分类,每次划分后,要求划分的两部分的信息熵的和最低。下面介绍下信息熵的概念
信息熵
信息熵可以看作是对于某个消息或者事件的平均预测的困难程度。
想象一下,如果你已经知道了一个事件一定会发生,那么这个事件就不会给你带来任何新的信息,因为你已经知道结果。相反,如果一个事件有很多可能的结果,你就无法准确地预测具体会发生什么,那么这个事件就会带来更多的信息。
信息熵的计算方法是基于事件发生的概率。如果一个事件有很高的概率发生,那么它所携带的信息量就会很低;而如果一个事件的发生概率很低,那么它所携带的信息量就会很高。信息熵的值越高,意味着事件的不确定性越大,需要更多的信息来描述。
如果用公式来表示的话,那就是:
H
=
−
∑
i
=
1
n
p
(
x
i
)
l
o
g
p
(
x
i
)
H = -\sum_{i=1}^np(x_i)log p(x_i)
H=−i=1∑np(xi)logp(xi)
其中
p
(
x
i
)
p(x_i)
p(xi)表示随时事件
x
i
x_i
xi发生的概率,而
−
l
o
g
p
(
x
i
)
-logp(x_i)
−logp(xi)则代表该事件的信息量
据此,我们就可以得到计算信息熵的方法了:
# 计算数据集的香农熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet: # 为所有可能分类创建字典
currentLabel = featVec[-1] # 取数据集的标签
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0 # 分类标签值初始化
labelCounts[currentLabel] += 1 # 给标签赋值
shannonEnt = 0.0 # 熵初始化
for key in labelCounts:
prob = float(labelCounts[key])/numEntries # 求得每个标签的概率 # L(Xi) = -log2P(Xi)
shannonEnt -= prob * log(prob, 2) # 以2为底求对数 # H = - Σi=1 n P(Xi)*log2P(Xi)
return shannonEnt
分类的实现
有了前面的基础知识,我们可以分类了,假设每次只取一种特征进行二分类,那么我们只需要计算每一种特征值分类后的熵,取最小的那个进行分类就可以了。同时由于该特征已经用过了,所以之后的样本种可以将该特征去除来进行计算。代码如下:
# 按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 计算特征的数目
baseEntropy = calcShannonEnt(dataSet) # 计算数据集的原始香农熵 用于与划分完的数据集的香农熵进行比较
bestInfoGain = 0.0 # 最佳信息增益初始化
bestFeature = -1 # 最佳划分特征初始化 TheBestFeatureToSplit
for i in range(numFeatures): # 遍历所有的特征
featList = [example[i] for example in dataSet] # 暂存每一个样本的第i个特征
uniqueVals = set(featList) # 特征去重
newEntropy = 0.0 # 划分后信息熵初始化
for value in uniqueVals: # 遍历去重后的特征 分别计算每一个划分后的香农熵
subDataSet = splitDataSet(dataSet, i, value) # 划分
prob = len(subDataSet)/float(len(dataSet)) # 算概率
newEntropy += prob * calcShannonEnt(subDataSet) # 算熵
infoGain = baseEntropy - newEntropy # 计算信息增益
if (infoGain > bestInfoGain): # 比较划分后的数据集的信息增益是否大于0 大于0 证明划分的有效
bestInfoGain = infoGain # 储存最佳信息增益值
bestFeature = i # 储存最佳特征值索引
return bestFeature # 返回最佳特征值索引
构建决策树
现在我们已经知道如何划分了,但是我们还不知道什么时候停止,虽然此处我们选择所有样本均属于同一类的时候停止,但是有的时候停止条件未必如此,最后的叶子节点中可能包含其他类,因此我们构造两个函数如下:
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0] # 结束划分 如果只有一种分类属性 属性标签重复
if len(dataSet[0]) == 1: # 结束划分 如果没有更多的特征了 都为同一类属性标签了
return majorityCnt(classList) # 计数排序 取最大数特征
bestFeat = chooseBestFeatureToSplit(dataSet) # 获取最优特征索引
bestFeatLabel = labels[bestFeat] # 获取最优特征属性标签
myTree = {bestFeatLabel: {}} # 决策树初始化 嵌套字典
del(labels[bestFeat]) # 删除已经使用的特征标签 这时应只剩下有脚蹼特征了
featValues = [example[bestFeat] for example in dataSet] # 取出数据集所有最优属性值
uniqueVals = set(featValues) # 去重
# 开始构建决策树
for value in uniqueVals:
subLabels = labels[:] # 得到剩下的所有特征标签 作为我们的子节点可用
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
实验
此时我们已经可以完整的构造一个决策树了(画图的代码参见下方完整代码),我们可以采用书中附赠的数据集来进行测试,实验结果如图:
完整代码如下:
from math import log
import operator
import matplotlib.pyplot as plt
# 计算数据集的香农熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet: # 为所有可能分类创建字典
currentLabel = featVec[-1] # 取数据集的标签
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0 # 分类标签值初始化
labelCounts[currentLabel] += 1 # 给标签赋值
shannonEnt = 0.0 # 熵初始化
for key in labelCounts:
prob = float(labelCounts[key])/numEntries # 求得每个标签的概率 # L(Xi) = -log2P(Xi)
shannonEnt -= prob * log(prob, 2) # 以2为底求对数 # H = - Σi=1 n P(Xi)*log2P(Xi)
return shannonEnt
# 按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 计算特征的数目
baseEntropy = calcShannonEnt(dataSet) # 计算数据集的原始香农熵 用于与划分完的数据集的香农熵进行比较
bestInfoGain = 0.0 # 最佳信息增益初始化
bestFeature = -1 # 最佳划分特征初始化 TheBestFeatureToSplit
for i in range(numFeatures): # 遍历所有的特征
featList = [example[i] for example in dataSet] # 暂存每一个样本的第i个特征
uniqueVals = set(featList) # 特征去重
newEntropy = 0.0 # 划分后信息熵初始化
for value in uniqueVals: # 遍历去重后的特征 分别计算每一个划分后的香农熵
subDataSet = splitDataSet(dataSet, i, value) # 划分
prob = len(subDataSet)/float(len(dataSet)) # 算概率
newEntropy += prob * calcShannonEnt(subDataSet) # 算熵
infoGain = baseEntropy - newEntropy # 计算信息增益
if (infoGain > bestInfoGain): # 比较划分后的数据集的信息增益是否大于0 大于0 证明划分的有效
bestInfoGain = infoGain # 储存最佳信息增益值
bestFeature = i # 储存最佳特征值索引
return bestFeature # 返回最佳特征值索引
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0] # 结束划分 如果只有一种分类属性 属性标签重复
if len(dataSet[0]) == 1: # 结束划分 如果没有更多的特征了 都为同一类属性标签了
return majorityCnt(classList) # 计数排序 取最大数特征
bestFeat = chooseBestFeatureToSplit(dataSet) # 获取最优特征索引
bestFeatLabel = labels[bestFeat] # 获取最优特征属性标签
myTree = {bestFeatLabel: {}} # 决策树初始化 嵌套字典
del(labels[bestFeat]) # 删除已经使用的特征标签 这时应只剩下有脚蹼特征了
featValues = [example[bestFeat] for example in dataSet] # 取出数据集所有最优属性值
uniqueVals = set(featValues) # 去重
# 开始构建决策树
for value in uniqueVals:
subLabels = labels[:] # 得到剩下的所有特征标签 作为我们的子节点可用
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
# 画图用的代码
def getNumLeafs(myTree):
numLeafs = 0
temp_keys = list(myTree.keys())
firstStr = temp_keys[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = next(iter(myTree))
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth:
maxDepth = thisDepth
return maxDepth
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else: # 叶子节点
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
if __name__ == "__main__":
with open("Ch03/lenses.txt", "rb") as fr:
lenses = [inst.decode().strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', "tearRate"]
lensesTree = createTree(lenses, lensesLabels)
createPlot(lensesTree)
小结
还算简单的创建决策树的一章,本章的书有很多内容集中在如何画图上,由于笔者感觉这一部分相对而言不如原理重要,因此没有进行详细的解析。总的来说,这一张最重要的思想应该就是如何分类以及树这一结构。