这是以本人的笔记的形式写的,各个函数逐个来写,至于存放在那个模块大家可以看书,这里不再详细讲解。可能存在错误,有不对的的地方希望评论给予改正。多谢大家嘻嘻🤭
from math import log #这个是加载数学函数
def calcShannonEnt(dataSet):
numEntries = len(dataSet) #计算这个数组的长度,注意这是个二维数组这个是计算里面有多少个小数组。
labelCounts = {} #创建一个字典
for featVec in dataSet: #遍历dataSet中的每一个小数组
currentLabel = featVec[-1] #取每个小数组的最后一个值
if currentLabel not in labelCounts.keys(): #判断labelCount有没有currentLabel键。如果没有,则把这个值作为这个函数的键,并将其值初始化为0.
labelCounts[currentLabel] = 0 #在上面一个已经解释。
labelCounts[currentLabel] += 1 #将这个键的值加1
shannonEnt = 0.0 #赋初始值
for key in labelCounts: #遍历每个键值
prob = float(labelCounts[key])/numEntries #将每个键出现的次数转换为浮点型,并除以总长度。
shannonEnt -= prob * log(prob,2) #求每个的期望值,并将它们相加。
return shannonEnt #返回所求的期望值和
这个函数的主要目的是求期望值。
def createDataSet(): #这个函数主要是为了提供数据。
dataSet = [[1,1,'yes'],
[1,1,'yes'],
[1,2,'no'],
[0,1,'no'],
[0,1,'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
这个函数的主要目的是保存数据
def splitDataSet(dataSet, axis, value): #三个参数第一个是二维数组,第二个是你要查找的位置,第三个是你要查找的对象。
retDataSet = [] #创建一个空的列表
for featVec in dataSet: 遍历dataSet中的每一个小列表
if featVec[axis] == value: #判断列表中axis对应的值是否等于value
reducedFeatVec = featVec[:axis] #这个和下面那个的主要目的就是为了去掉featVec中axis所对应的值。
reducedFeatVec.extend(featVec[axis+1:]) #这里要区分一下extend和append。例如a=[1,2,3],b=[3,4,5],a.append则为[1,2,3,[4,5,6]]而extend的值为[1,2,3,4,5,6]
retDataSet.append(reducedFeatVec) #这个是为了把变化后的每个小列表添加到retDataSet中
return retDataSet #返回变化后的值
这个函数的主要目的是对dataSet中的数据进行分类把位置axis对应的数据是否等于value进行分类并将axis这个位置的数据去掉。
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 #求第一层的个数减1,这个主要是为了后面的遍历求解。
baseEntropy = calcShannonEnt(dataSet) #求dataSet的期望值
bestInfoGain = 0.0; bestFeature = -1 #赋初始值
for i in range(numFeatures): #循环遍历
featList = [example[i] for example in dataSet] #把dataSet中的数据存入到example中,依次读取example中的小列表。这种读取是对列的读取,就是保存所有的列。一列一列的保存。
uniqueVals = set(featList) #去重处理
newEntropy = 0.0 #给newEntropy赋初始值
for value in uniqueVals: #对uniqueVals里面的各个数进行遍历。
subDataSet = splitDataSet(dataSet, i, value) #求出每个数对应的期望值将其存入到subDataSet中
prob = len(subDataSet)/float(len(dataSet)) #求出这个标签所对应的属性与所有属性的比例
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy #这个代码和前两个代码统称为信息增益,信息增益就是有一个固定熵减去一个条件熵下面是求得的最大熵。
if (infoGain > bestInfoGain):
bestInfoGain = infoGain #将最大的熵赋值bestInfoGain
bestFeature = i #并将序号保存
return bestFeature #返回序号
求最佳方案
def majorityCnt(classList):
classCount = {} #创建一个空字典
for vote in classList: #访问classList中的每个元素
if vote not in classCount.keys(): classCount[vote] = 0 #判断classCount中是否含有vote所对应的值,如果没有则将value所对应的值用作键,并将其值赋值为0;
classCount[vote] += 1 #把所对应的键值加1
sortedClassCount = sorted(classCount, key=operator.itemgetter(1), reverse=True) #排序前面我的第一张有讲,这里不做陈述。
return sortedClassCount[0][0] #返回占有率最大的那个
这个函数主要是分类找到距离最近的
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet] #前面有讲
if classList.count(classList[0]) == len(classList): #这个的目的就是为了判断classList是否只有一个类别,如果是则停止划分。
return classList[0]
if len(dataSet[0]) == 1: #遍历完所有类型仍没有结束
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) #求出最佳方案输出其列序号
bestFeatLabel = labels[bestFeat] #特征对应的标签
myTree = {bestFeatLabel:{}} #建立树节点即维度标签,并赋予空值
del(labels[bestFeat]) #从标签列表中删除已选维度标签
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues) #同上就是筛选掉重复的值
for value in uniqueVals: #将最优维度对应每个值,一个值对应一个分支。
subLabels = labels[:] #将labels的值赋值给subLabels
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels) #递归建树
return myTree #返回树
以下函数每次修改都需要转换代码,我的是这样的。不会转换的可以去网上搜,也可以私聊。
#! D:/software/anaconda3/python.exe
#-*- coding:UTF-8 -*- #这两个是为了将中文字符串读出来,还有就是要把这段代码转换为utf-8,要不然会出现SyntaxError: (unicode error) 'utf-8' codec can't decode byte 0xbe in position 0: invalid start byte这种错误
import matplotlib.pyplot as plt #加载绘图函数
decisionNode = dict(boxstyle = "sawtooth", fc="0.8") #设置小框框的的样式,小框框中的背景色深度
leafNode = dict(boxstyle="round4", fc="0.8") #与上面中的一样
arrow_args = dict(arrowstyle="<-") #设置箭头样式
这个主要是为了设置框的形式,框内背景色,箭头的样式注意箭头的方向
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
#这个函数是为了执行实际的绘图功能。
def createPlot():
fig=plt.figure(1, facecolor='white') #这个是设置背景色,是整个图的背景色。
fig.clf() #清空画布
createPlot.ax1 = plt.subplot(111, frameon=False) #用来定义全局变量绘图区
plotNode("决策节点" , (0.5, 0.1), (0.1, 0.5), decisionNode) 调用函数绘图
plotNode('叶节点' , (0.8, 0.1), (0.3, 0.8), leafNode) #调用函数绘图
plt.show() #展示图像
#这个函数主要是为了绘制图像
def getNumLeafs(myTree):
numLeafs = 0 #赋初始值求叶节点
firstStr = list(myTree.keys())[0] #求出第一颗树的根节点
secondDict = myTree[firstStr] #获得第一个子节点
for key in secondDict.keys(): #循环遍历每个子结点的根节点
if type (secondDict[key]).__name__=='dict': #如果这个节点不是叶子节点则递归继续寻找
numLeafs += getNumLeafs(secondDict[key])
else: #否则对应值加1
numLeafs += 1
return numLeafs
#这个函数是为了求叶子节点的个数
def getTreeDepth(myTree):
maxDepth = 0 #赋初始值求深度
firstStr = list(myTree.keys())[0] #与上一个函数一样
secondDict = myTree[firstStr] #与上一个函数一样
for key in secondDict.keys(): #同上
if type(secondDict[key]).__name__=='dict': 如果这个节点不是叶子节点,则1加上其递归值
thisDepth = 1 + getTreeDepth(secondDict[key])
else: #否则其值为1
thisDepth = 1
if thisDepth > maxDepth: #取最大的深度
maxDepth = thisDepth #如果是最大的深度则将次深度赋值为最大值
return maxDepth #返回最大值
这个函数是为了求最大深度
def retrieveTree(i):
listOfTree = [{'no surfacing': {0:'no', 1: {'flippers': {0:'no', 1:'yes'}}}},
{'no surfacing': {0: 'no', 1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
]
return listOfTree[i]
这个函数只要是为了创建验证数据
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
#用父子节点填充文本信息
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) #获取树的叶子个数,计算树的宽
depth = getTreeDepth(myTree) #获取树的深度,计算树的高
firstStr = list(myTree.keys())[0] #获取第一个键
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) #计算当前的节点数
plotMidText(cntrPt, parentPt, nodeTxt) #调用plotMidText是为了绘出节点当前具有的特征值。计算父子节点的中间位置,并在此填充文本
plotNode(firstStr, cntrPt, parentPt, decisionNode) #调用plotNode是为了
secondDict = myTree[firstStr] 获取键内的树
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #因从上往下画因此需要递减y的值
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW #增加全局变量X的偏移量
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) #上面有解释
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) #上面有解释
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #增加全局变量Y的偏移量
def createPlot(inTree):
fig = plt.figure(1, facecolor = 'white') #前面有讲
fig.clf() #前面有讲
axprops = dict(xticks=[], yticks=[]) #设置横纵坐标
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #由全局变量createPlot.ax1绘图111表示一行一列第一个,frameon表示边框,**axprops表示是否显示刻度
plotTree.totalW = float(getNumLeafs(inTree)) #获取叶子的个数
plotTree.totalD = float(getTreeDepth(inTree)) #获取最大深度,这两个全局变量主要为了使树能有一个好的摆放
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; #求出初始的全局变量xOff 和yOff,追踪已经绘制的点,以及下一个要绘制的恰当位置
plotTree(inTree, (0.5, 1.0), '') 调用函数绘制图像
plt.show() #显示图像
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0] #获取该字典的第一个键
secondDict = inputTree[firstStr] #获取相应键所对应的值
featIndex = featLabels.index(firstStr) #index列表查找第一个匹配firstStr的索引。所谓索引就是你能通过位置
for key in secondDict.keys(): #遍历第一个键的值中的键
if testVec[featIndex] == key: 这个是为了遍历树中的每个子根,直到遍历你想查找的是对应的值,也就是testVec[featIndex]对应的值等于key。
if type(secondDict[key]).__name__=='dict': #判断你所找到的键是不是字典,如果是则继续递归,不是则输出结果。
classLabel = classify(secondDict[key], featLabels, testVec) #递归用的
else:
classLabel = secondDict[key] 输出相应的值
return classLabel #返回相应的值
这个函数的目的是分类,寻找叶子节点,遍历叶子节点,返回最终的寻找值
这个函数目的分类标签求出相应的叶节点
pickle模块可以序列化列表,字典,集合,类等,将这些以文件的形式存入磁盘,但是人们不易阅读。
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'wb+') #创建一个新文件这里的wb+是因为我要输入字符串,如果是数字则改成w即可。
pickle.dump(inputTree, fw) #写入新创建文件, 序列化对象
fw.close() #保存文件
这个函数主要是为了创建文件,序列化对象,存入数据
def grabTree(filename):
import pickle
fr = open(filename, 'rb') #打开文件注意这里的rb,如果我输入的不是字符串则不能这么写
return pickle.load(fr) #返回加载后的内容,也就是反序列化
这个函数的目的读取文件,反序列化,返回读取后的结果。