决策树 ID3 算法原理比较简单,处理数据的过程也比较容易理解,简单的说就是以数据集的香农熵作为判据,利用树的数据结构递归对数据集进行分类。
STEP1
计算数据集的香农熵
#计算数据集的香农熵
def calcShannonEnt(dataSet):
#获取数据个数
numEntries = len(dataSet)
#创建空字典,用于存储 label:个数 的键值对
labelCounts = {}
#遍历数据集
for featVec in dataSet:
#取当前数据的 label
currentLabel = featVec[-1]
#若该 label 还未加入字典
if currentLabel not in labelCounts.keys():
#将该 label 加入字典,置值为 0
labelCounts[currentLabel] = 0
#记录 label 的数量
labelCounts[currentLabel] += 1
#根据公式,计算该数据集的香农熵
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key] / numEntries)
shannonEnt -= prob * log(prob, 2)
return shannonEnt
以一个简易数据集为例(两个特征量 + 一个 label):
dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,0,'no']]
print(calcShannonEnt(dataSet))
处理得到的字典 labelCounts 为:
{'yes': 2, 'no': 3}
计算结果为:
0.9709505944546686
Process finished with exit code 0
对这个数据集稍作更改:
dataSet = [[1,1,'maybe'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,0,'no']]
print(calcShannonEnt(dataSet))
处理得到的字典 labelCounts 为:
{'maybe': 1, 'yes': 1, 'no': 3}
计算结果为:
1.3709505944546687
Process finished with exit code 0
STEP2
根据香农熵划分数据集
首先编写一个辅助函数
#筛选 featVec[axis] = value 的数据
#剔除用于分类的 featVec[axis] 的信息
#结果以列表形式返回
def splitDataSet(dataSet, axis, value):
#创建用于返回的空列表
retDataSet = []
#遍历输入的 dataSet
for featVec in dataSet:
#如果符合筛选条件
if featVec[axis] == value:
#复制 featVec[axis] 之前的信息
reduceFeatVec = featVec[:axis]
#复制 featVec[axis] 之后的信息
reduceFeatVec.extend(featVec[axis+1:])
#将剔除了 featVec[axis] 的数据加入要返回的列表
retDataSet.append(reduceFeatVec)
return retDataSet
举个栗子就很好理解了:
dataSet = [[1,'A','yes'],[1,'A','yes'],[1,'B','no'],[0,'A','no'],[0,'B','no']]
print(splitDataSet(dataSet,0,0))
print(splitDataSet(dataSet,0,1))
print(splitDataSet(dataSet,1,'A'))
print(splitDataSet(dataSet,1,'B'))
得到的结果为:
[['A', 'no'], ['B', 'no']]
[['A', 'yes'], ['A', 'yes'], ['B', 'no']]
[[1, 'yes'], [1, 'yes'], [0, 'no']]
[[1, 'no'], [0, 'no']]
Process finished with exit code 0
STEP3
利用前面两个函数,构建用于选择最好的数据划分方式的函数
#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
#获取输入数据集的特征数量
numFeatures = len(dataSet[0]) - 1
#计算划分前数据集的香农熵 H(D)
baseEntropy = calcShannonEnt(dataSet)
#初始化最大信息增益和对应的下标
bestInfoGain = 0.0
bestFeature = -1
#遍历所有特征
for i in range(numFeatures):
#将所有数据的该特征值写入列表
featList = [example[i] for example in dataSet]
#利用 set 函数去重
uniqueVals = set(featList)
#初始化根据该特征进行分类后的香农熵
newEntropy = 0.0
#遍历该特征的所有取值
for value in uniqueVals:
#利用前面编写的函数 splitDataSet 对数据集进行划分
subDataSet = splitDataSet(dataSet, i, value)
#根据公式计算划分后的香农熵 H(D|A)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
#根据公式计算信息增益 g(D,A) = H(D) - H(D|A)
infoGain = baseEntropy - newEntropy
#选取最大信息增益
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
例如输入:
dataSet = [[1,'A','yes'],[1,'A','yes'],[1,'B','no'],[0,'A','no'],[0,'A','no']]
print(chooseBestFeatureToSplit(dataSet))
得到的结果为:
0
Process finished with exit code 0
STEP4
递归构建决策树
这里不熟悉的可以先看一下数据结构里面关于创建树的内容。
#辅助函数:多数表决
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1), reverse = True)
return sortedClassCount[0][0]
#使用字典结构存储树信息
def creatTree(dataSet, featText):
subFeatText = featText[:]
#将所有数据的 label 存入 classList 列表
classList = [example[-1] for example in dataSet]
#设置返回条件(已到达叶子节点)
#返回条件 1:数据集所有 label 相同 -> 直接返回该 label
if classList.count(classList[0]) == len(classList):
return classList[0]
#返回条件 2:数据集所有特征都以遍历完 -> 返回数目最多都 label(多数表决)
if len(dataSet[0]) == 1:
return majorityCnt(classList)
#处理当前节点(未到达叶子节点)
#使用前面编写的函数获取要作为下一步划分依据的特征
bestFeat = chooseBestFeatureToSplit(dataSet)
#获取下一步划分依据的特征的文本描述
bestFeatLabel = subFeatText[bestFeat]
#初始化此节点往下的(子)树的数据结构
myTree = {bestFeatLabel:{}}
#删除即将使用的特征
del(subFeatText[bestFeat])
#将要用于分类的特征在数据集中的所有数值写入列表
featValues = [example[bestFeat] for example in dataSet]
#使用 set 函数去重
uniqueVales = set(featValues)
#遍历所有剩余特征的文本表示
for value in uniqueVales:
#递归调用 creatTree 处理当前节点分支产生的子节点
myTree[bestFeatLabel][value] = creatTree(splitDataSet(dataSet,bestFeat,value),subFeatText)
return myTree
这里使用一个简易的数据集:
不浮出水面是否可以生存 | 是否有脚蹼 | label(属于鱼类) |
1 | 1 | yes |
1 | 1 | yes |
1 | 0 | no |
0 | 1 | no |
0 | 1 | no |
首先进入函数,获取 label 列表,判断未达到返回条件,计算出本次划分要基于哪个特征(这里有两个特征供选择):
featText:
['不浮出水面是否可以生存', '是否有脚蹼']
classList:
['yes', 'yes', 'no', 'no', 'no']
bestFeat:
0
bestFeatLabel:
不浮出水面是否可以生存
这里得到的结果是根据 特征[0] 进行划分,这一列特征的文本意义是“不浮出水面是否可以生存”,这样数据集就被分成了两个子集,删除已使用的特征,接着对分叉形成的各个子节点循环执行递归。
对第一个子集的递归,获取 label 列表,判断达到返回条件(所有 label 相同),直接返回:
featText:
['是否有脚蹼']
classList:
['no', 'no']
叶子节点:
no
对第二个子集的递归,获取 label 列表,判断未达到返回条件,计算出本次划分要基于哪个特征(这里只剩一个特征供选择):
featText:
['是否有脚蹼']
classList:
['yes', 'yes', 'no']
bestFeat:
0
bestFeatLabel:
是否有脚蹼
这里得到的结果是对 特征[0] 进行划分(前面使用过的特征已被删除),这一列特征的文本意义是“是否有脚蹼”,这样数据集就又被分成了两个子集,删除已使用的特征,接着对分叉形成的各个子节点循环执行递归。
对第一个子集的递归,获取 label 列表,判断达到返回条件(所有特征都已使用),直接返回:
featText:
[]
classList:
['no']
叶子节点:
no
对第二个子集的递归,获取 label 列表,判断达到返回条件(所有特征都已使用),直接返回:
featText:
[]
classList:
['yes', 'yes']
叶子节点:
yes
(注:如果使用完了所有的特征,出现的 label 还是不同,就会执行多数表决。)
因为是递归调用,所以返回的顺序也是自底向上的,首先返回子节点的树结构:
{'是否有脚蹼': {0: 'no', 1: 'yes'}}
最终返回得到的整个决策树:
{'不浮出水面是否可以生存': {0: 'no', 1: {'是否有脚蹼': {0: 'no', 1: 'yes'}}}}
Process finished with exit code 0
这里最终得到的树结构为:
结合程序和结构图仔细理解递归的处理过程,数据返回的顺序为:no surfacing 的 no 的叶子结点、flippers 的 no 的叶子结点、flippers 的 yes 的叶子结点、flippers 的子节点、no surfacing 的跟节点。
STEP5
使用训练的决策树进行分类
这里不熟悉的可以先看一下数据结构里面关于遍历树的内容。
#对测试数据进行分类
def classify(inputTree, featText, testVec):
#获取根节点的内容:第一个分类特征的文本描述
firstStr = list(inputTree.keys())[0]
#获取根节点下的字典结构
secondDict = inputTree[firstStr]
#获取根节点分类特征的下标
featIndex = featText.index(firstStr)
#循环处理根节点分叉的各个子节点
for key in secondDict.keys():
#找到该测试数据对应的分支节点
if testVec[featIndex] == key:
#如果该子节点不是叶子节点,递归查找
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key],featText,testVec)
#如果该子节点是叶子结点,返回
else:
classLabel = secondDict[key]
return classLabel
还是说明一下查找的过程:
dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
featText = ['不浮出水面是否可以生存','是否有脚蹼']
myTree = creatTree(dataSet, featText)
print(myTree)
print(classify(myTree,featText,[1,1]))
首先得到的决策树结构和之前描述的一致:
{'不浮出水面是否可以生存': {0: 'no', 1: {'是否有脚蹼': {0: 'no', 1: 'yes'}}}}
第一步,根据第一个分支进行判断,根据特征“不浮出水面是否可以生存”,该测试数据的这一项特征值为 1,特征为 1 的节点不是叶子结点,需要递归查找:
firstStr:
不浮出水面是否可以生存
secondDict:
{0: 'no', 1: {'是否有脚蹼': {0: 'no', 1: 'yes'}}}
key:
1
第二步,根据第一个分支进行判断,根据特征“是否有脚蹼”,该测试数据的这一项特征值为 1,特征为 1 的节点是叶子结点,得到结果:
firstStr:
是否有脚蹼
secondDict:
{0: 'no', 1: 'yes'}
key:
1
yes
ID3 算法存在过拟合问题,一个有效的处理方法是剪枝。