1.算法原理
决策树是一棵树,它的每个节点都是一次决策,该节点的子树分别代表不同的决策,叶子节点表示所有数据已经属于同一类型,无法再分。
因此构造决策树只需要做一件事,找出划分当前数据集的最优特征,之后递归子树即可把决策树构造出来。
1.1 找出最优特征
那如何找出最优特征呢,可以从信息论的方向出发,在划分数据前后使用信息论量化度量信息的内容,选取信息增益最高的特征作为当前的选择。
熵定义为信息的期望值,因此需要找熵最大的划分。
这里采用ID3算法去计算熵
总结一下,我们需要对数据集的每个特征都尝试划分一次,并计算熵,最后选取使得熵最大的特征划分为当前数据集的划分。
1.2 递归子树
递归的终止条件有两个:
- 每个分支下的所有实例都具有相同分类
- 遍历完所有划分数据集的属性
情况1,所有实例都具有相同分类,这就是一个叶子节点。
情况2,可以采用多数表决的方法,将标签出现频率最高的做为此时的分类
2. 代码
import operator
from math import log
import pickle
"""
函数说明:
计算数据集的香农熵
公式:H = - ( for i in range(n): p(xi) * log(p(xi),2) )
参数:
dataSet: 数据集
返回值:
shannonEnt: 香农熵
"""
def calcShannonEnt (dataSet):
# numEntries = 数据集的行数
numEntries = len(dataSet)
# labelCounts 是记录标签出现次数的字典
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
# 遍历标签,计算所有标签的香农熵
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
"""
函数说明:
划分数据集
参数:
dataSet: 数据集
axis: 需要去掉特征的索引值
value: 需要返回特征值
返回值:
retDataSet: 返回划分后的结果集
"""
def splitDataSet(dataSet, axis, value):
retDataSet = []
# 遍历数据集
for featVec in dataSet:
if featVec[axis] == value:
# 去掉 axis 特征,
# extend 用于扩展列表,在末尾追加数据,得到的还是一个列表
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
# append 则是将整个列表增加进去,得到一个嵌套列表
retDataSet.append(reducedFeatVec)
return retDataSet
"""
函数说明:
选择最优的数据集划分方式
参数:
dataset:数据集
返回值:
bestFeature:最优特征的索引值
"""
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
# 计算数据集的香农熵
baseEntropy = calcShannonEnt(dataSet)
# 最优信息增益
bestInfoGain = 0.0
# 最优特征的索引值
bestFeature = -1
for i in range(numFeatures):
# 将 dataset 第i个特征存在 featList 里(列表推导)
featList = [example[i] for example in dataSet]
# 建立集合,得到列表中唯一元素值
uniqueVals = set(featList)
newEntropy = 0.0
# 遍历当前特征的所有唯一属性值,划分数据集,计算新熵值
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
"""
函数说明:
测试熵的自建数据集
参数:
无
返回值:
dataSet: 数据集
labels: 标签
"""
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
"""
函数说明:
统计classList中出现次数最多的元素(类标签)
服务于递归第两个终止条件
参数:
classList:类标签列表
返回值:
sortedClassCount[0][0]:出现次数最多的元素(类标签)
"""
def majorityCnt(classList):
classCount = {}
# 统计classList中每个元素出现的次数
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
# 根据字典的值降序排序
# operator.itemgetter(1)获取对象的第1列的值
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
# 返回classList中出现次数最多的元素
return sortedClassCount[0][0]
"""
函数说明:
创建决策树(ID3算法)
参数:
dataSet: 数据集
labels: 标签
返回值:
myTree: 决策树(字典表示)
"""
def createTree(dataSet, labels):
# 将所有标签存在 classList 里
classList = [example[-1] for example in dataSet]
# 递归终止条件:类别完全相同则停止继续划分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 递归终止条件:遍历完所有特征时,返回出现次数最多的类别
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 得到最优特征的索引值 bestFeat
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel: {}}
# 删除已经使用过的标签
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
# python函数参数是列表类型时,参数是安装引用方式传递的
# 为保护原列表,每次都用一个新列表来代替原始列表
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
"""
函数说明:
使用决策树分类
参数:
inputTree - 已经生成的决策树
featLabels - 存储选择的最优特征标签
testVec - 测试数据列表,顺序对应最优特征标签
返回值:
classLabel - 分类结果
"""
def classify(inputTree, featLabels, testVec):
# 将标签字符串转换成索引
# 获取标签
firstStr = inputTree.keys()[0]
# 获得标签在树中所对应的字典的值
secondDict = inputTree[firstStr]
# 获取标签在原标签列表的索引值
featIndex = featLabels.index(firstStr)
for key in secondDict.key():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
def storeTree(inputTree, filename):
fw = open(filename, 'w')
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
fr = open(filename)
return pickle.load(fr)
def main():
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
mTree = createTree(lenses, lensesLabels)
print(mTree)
# myDat, labels = createDataSet()
# print(createTree(myDat, labels))
# print(chooseBestFeatureToSplit(myDat))
# print(splitDataSet(myDat, 0, 1))
# print(calcShannonEnt(myDat))
# myDat[0][-1] = 'maybe'
# print(calcShannonEnt(myDat))
if __name__ == '__main__':
main()