引用数据集获取:
如何构造一个决策树?
我们使用 createBranch() 方法,如下所示:
def createBranch():
'''
此处运用了递归的思想。 感兴趣可以搜索 递归 recursion, 甚至是 dynamic programming。
'''
检测数据集中的所有数据的分类标签是否相同:
If so return 类标签
Else:
寻找划分数据集的最好特征(划分之后信息熵最小,也就是信息增益最大的特征)
划分数据集
创建分支节点
for 每个划分的子集
调用函数 createBranch (创建分支的函数)并增加返回结果到分支节点中
return 分支节点
参看西瓜书76页的西瓜数据集2.0
def createDataList():
    """Return the watermelon dataset 2.0 (西瓜书 p.76) and its feature labels.

    Each row is [色泽, 根蒂, 敲击, 纹理, 脐部, 触感, class-label],
    where the last column is the class ('好瓜' / '坏瓜').

    Returns:
        (dataList, labels): the 17 samples and the 6 feature names.
    """
    # Samples 1-8: positive class (好瓜).
    good_melons = [
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],
    ]
    # Samples 9-17: negative class (坏瓜).
    bad_melons = [
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
    ]
    labels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
    return good_melons + bad_melons, labels
def calcShannonEnt(dataList):
    """Compute the Shannon entropy of the class labels in `dataList`.

    The class label is taken from the LAST column of each row.

    Args:
        dataList: list of rows; each row's last element is its class label.
                  Must be non-empty.

    Returns:
        float: entropy in bits, -sum(p * log2(p)) over the label distribution.
    """
    # Bug fix: the original used `np.math.log(prob, 2)` — `np` is never
    # imported in this file, and `np.math` was removed from modern NumPy.
    # Function-scope import because the file has no top-level import block.
    from math import log2
    total = len(dataList)
    labelCounts = {}
    for featVec in dataList:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / total
        shannonEnt -= prob * log2(prob)
    return shannonEnt
按照给定特征划分数据集
def splitDataList(dataList, index, value):
    """Select the rows whose feature at position `index` equals `value`,
    returning them with that feature column removed.

    Args:
        dataList: list of feature rows.
        index: column index of the feature to filter/drop.
        value: feature value a row must match to be kept.

    Returns:
        A new list of new rows; the input rows are not modified.
    """
    return [
        row[:index] + row[index + 1:]
        for row in dataList
        if row[index] == value
    ]
选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataList):
    """Pick the feature with the largest information gain (ID3 criterion).

    Args:
        dataList: non-empty list of rows whose last column is the class label;
                  all other columns are candidate split features.

    Returns:
        int: index of the best feature, or -1 if no split improves on the
        base entropy (every feature yields zero gain).
    """
    # Bug fix: the original left a debug `print` in the loop whose message
    # labelled the *current* index `i` as 'bestFeature' — misleading output
    # removed from this library routine.
    numFeatures = len(dataList[0]) - 1
    baseEnt = calcShannonEnt(dataList)
    bestInfoGain, bestFeature = 0.0, -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataList)
        # Conditional entropy after splitting on feature i.
        newEnt = 0.0
        for value in uniqueVals:
            subDataList = splitDataList(dataList, i, value)
            prob = len(subDataList) / float(len(dataList))
            newEnt += prob * calcShannonEnt(subDataList)
        infoGain = baseEnt - newEnt
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
当划分到最后特征不一致时,少数服从多数
def majorityCnt(classList):
    """Return the most frequent class label in `classList` (majority vote).

    Used when all features are exhausted but the leaf is still impure.

    Args:
        classList: non-empty list of class labels.

    Returns:
        The label with the highest count; ties are broken in favour of the
        label first reaching the maximum count.
    """
    # Bug fix: the original called `classCount.iteritems()` (Python 2 only —
    # AttributeError on Python 3) and used the never-imported `operator`
    # module. `max` over .items() needs neither.
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    return max(classCount.items(), key=lambda kv: kv[1])[0]
创建树
def createTree(dataList, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataList: non-empty list of rows; last column is the class label.
        labels: feature names aligned with the feature columns of dataList.
                NOT modified (see bug-fix note below).

    Returns:
        Either a class label (leaf) or a nested dict of the form
        {featureLabel: {featureValue: subtree, ...}}.
    """
    classList = [example[-1] for example in dataList]
    # Stop 1: all samples share one class -> pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop 2: no features left (only the label column) -> majority vote.
    if len(dataList[0]) == 1:
        return majorityCnt(classList)
    bestFeatIdx = chooseBestFeatureToSplit(dataList)
    bestFeatLabel = labels[bestFeatIdx]
    myTree = {bestFeatLabel: {}}
    # Bug fix: the original did `del(labels[bestFeatIdx])`, mutating the
    # CALLER's labels list as a side effect. Build the reduced label list
    # on a copy instead.
    subLabels = labels[:bestFeatIdx] + labels[bestFeatIdx + 1:]
    uniqueVals = set(example[bestFeatIdx] for example in dataList)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataList(dataList, bestFeatIdx, value), subLabels)
    return myTree
调用
if __name__ == '__main__':
    # Build the watermelon dataset, grow the ID3 tree, and show the result.
    samples, featureNames = createDataList()
    decisionTree = createTree(samples, featureNames)
    print(decisionTree)