from math import log

def calcShannonEnt(dataSet):
    # Compute the Shannon entropy of a data set whose last column is the class label.
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    # labelCounts now maps each class label to the number of times it occurs.
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
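# Worked example (a sketch): Shannon entropy is H = -sum(p_i * log2(p_i)) over the
# class probabilities. For the 5-sample data set from createDataSet below
# (2 "yes", 3 "no"):
#   H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971
# which is the value printed by the main block at the bottom of this file.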
def createDataSet():
    dataSet = [[1, 1, "yes"],
               [1, 1, "yes"],
               [1, 0, "no"],
               [0, 1, "no"],
               [0, 1, "no"]]
    labels = ["no surfacing", "flippers"]
    return dataSet, labels
def splitDataSet(dataSet, axis, value):
    # Return the rows whose feature at index `axis` equals `value`,
    # with that feature column removed from each returned row.
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
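# Usage sketch: splitting the sample data on axis 0 with value 1 keeps the three rows
# whose first feature is 1 and strips that feature out:
#   splitDataSet(createDataSet()[0], 0, 1)  ->  [[1, 'yes'], [1, 'yes'], [0, 'no']]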
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1        # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)    # entropy of the full data set
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)           # unique values of feature i (a set removes duplicates)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # information gain from splitting on feature i
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
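# Worked example (a sketch) on the sample data, where the base entropy is about 0.971:
#   feature 0: newEntropy = (3/5)*0.918 + (2/5)*0.0 ≈ 0.551, gain ≈ 0.420
#   feature 1: newEntropy = (4/5)*1.000 + (1/5)*0.0 = 0.800, gain ≈ 0.171
# so chooseBestFeatureToSplit returns index 0, i.e. the "no surfacing" feature.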
import operator

def majorityCnt(classList):
    # Return the class label that appears most often in classList.
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
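# Usage sketch: majorityCnt(["yes", "no", "no"]) returns "no". createTree falls back on
# this majority vote when every feature has been used but the class labels still disagree.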
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # print("classList:", classList)
    # Stop when every remaining sample has the same class label.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop when no features are left; return the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]   # copy so the recursive calls do not share one labels list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
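# Expected result on the sample data (also printed by the main block below):
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# Note that createTree deletes the chosen feature from the labels list it is given,
# so pass a copy if you still need the original label names afterwards.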
import matplotlib.pyplot as plt

# Text-box and arrow styles used for plotting the tree.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Draw an annotated node at centerPt with an arrow pointing from parentPt.
# Relies on createPlot() having set the createPlot.ax1 axes attribute first.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords="axes fraction",
                            xytext=centerPt, textcoords="axes fraction",
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot():
    fig = plt.figure(1, facecolor="white")
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode("decision node", (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode("leaf node", (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
def getNumLeafs(myTree):
    # Count the leaf nodes of a tree stored as nested dicts.
    numLeafs = 0
    firstStr = list(myTree.keys())[0]   # in Python 3, dict_keys does not support indexing
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    # Count the number of decision levels in the tree.
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
def retrieveTree(i):
    # Pre-built trees for testing the plotting and measuring functions.
    listOfTrees = [{"no surfacing": {0: "no", 1: {"flippers": {0: "no", 1: "yes"}}}},
                   {"no surfacing": {0: "no", 1: {"flippers": {0: {"head": {0: "no", 1: "yes"}}, 1: "no"}}}}]
    return listOfTrees[i]
if __name__ == "__main__":
    myDat, labels = createDataSet()
    print("myDat:", myDat)
    print("calcShannonEnt(myDat):", calcShannonEnt(myDat))
    print("---------------- add a third class ---------------------------------")
    myDat[0][-1] = "maybe"   # a third class label increases the entropy
    print("myDat:", myDat)
    print("calcShannonEnt(myDat):", calcShannonEnt(myDat))
    print("---------------- splitDataSet ---------------------------------")
    # append inserts the whole list as a single element; extend merges its elements.
    a = [1, 2, 3]
    b = [4, 5, 6]
    a.append(b)
    print("a:", a)   # [1, 2, 3, [4, 5, 6]]
    a = [1, 2, 3]
    a.extend(b)
    print("a:", a)   # [1, 2, 3, 4, 5, 6]
    myDat, labels = createDataSet()
    print("myDat:", myDat)
    temp = splitDataSet(myDat, 0, 0)
    print("temp(0,0):", temp)
    print("---------------- splitDataSet ---------------------------------")
    temp = splitDataSet(myDat, 0, 1)
    print("temp(0,1):", temp)
    print("---------------- chooseBestFeatureToSplit ---------------------------------")
    myDat, labels = createDataSet()
    print("myDat:", myDat)
    temp = chooseBestFeatureToSplit(myDat)
    print("temp:", temp)
    print("---------------- createTree ---------------------------------")
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels)
    print("myTree:", myTree)
    print("---------------- createPlot ---------------------------------")
    # createPlot()