机器学习实战第三章决策树

这是以本人的笔记的形式写的,各个函数逐个来写,至于存放在那个模块大家可以看书,这里不再详细讲解。可能存在错误,有不对的的地方希望评论给予改正。多谢大家嘻嘻🤭

from math import log      #这个是加载数学函数

def calcShannonEnt(dataSet):
	numEntries = len(dataSet)      #计算这个数组的长度,注意这是个二维数组这个是计算里面有多少个小数组。
	labelCounts = {}      #创建一个字典
	for featVec in dataSet:      #遍历dataSet中的每一个小数组
		currentLabel = featVec[-1]      #取每个小数组的最后一个值
		if currentLabel not in labelCounts.keys():      #判断labelCount有没有currentLabel键。如果没有,则把这个值作为这个函数的键,并将其值初始化为0.
			labelCounts[currentLabel] = 0      #在上面一个已经解释。
		labelCounts[currentLabel] += 1      #将这个键的值加1
	shannonEnt = 0.0      #赋初始值
	for key in labelCounts:      #遍历每个键值
		prob = float(labelCounts[key])/numEntries      #将每个键出现的次数转换为浮点型,并除以总长度。
		shannonEnt -= prob * log(prob,2)      #求每个的期望值,并将它们相加。
	return shannonEnt      #返回所求的期望值和
	这个函数的主要目的是求期望值。
def createDataSet():      #这个函数主要是为了提供数据。 
	dataSet = [[1,1,'yes'],
			[1,1,'yes'],
			[1,2,'no'],
			[0,1,'no'],
			[0,1,'no']]
			
	labels = ['no surfacing', 'flippers']
	return dataSet, labels
	这个函数的主要目的是保存数据
def splitDataSet(dataSet, axis, value):      #三个参数第一个是二维数组,第二个是你要查找的位置,第三个是你要查找的对象。
	retDataSet = []      #创建一个空的列表
	for featVec in dataSet:      遍历dataSet中的每一个小列表
		if featVec[axis] == value:      #判断列表中axis对应的值是否等于value
			reducedFeatVec = featVec[:axis]      #这个和下面那个的主要目的就是为了去掉featVec中axis所对应的值。
			reducedFeatVec.extend(featVec[axis+1:])      #这里要区分一下extend和append。例如a=[1,2,3],b=[3,4,5],a.append则为[1,2,3,[4,5,6]]而extend的值为[1,2,3,4,5,6]
			retDataSet.append(reducedFeatVec)      #这个是为了把变化后的每个小列表添加到retDataSet中
	return retDataSet      #返回变化后的值
	这个函数的主要目的是对dataSet中的数据进行分类把位置axis对应的数据是否等于value进行分类并将axis这个位置的数据去掉。
def chooseBestFeatureToSplit(dataSet):
	numFeatures = len(dataSet[0]) - 1      #求第一层的个数减1,这个主要是为了后面的遍历求解。
	baseEntropy = calcShannonEnt(dataSet)      #求dataSet的期望值
	bestInfoGain = 0.0; bestFeature = -1      #赋初始值
	for i in range(numFeatures):      #循环遍历
		featList = [example[i] for example in dataSet]      #把dataSet中的数据存入到example中,依次读取example中的小列表。这种读取是对列的读取,就是保存所有的列。一列一列的保存。
		uniqueVals = set(featList)      #去重处理
		newEntropy = 0.0      #给newEntropy赋初始值
		for value in uniqueVals:      #对uniqueVals里面的各个数进行遍历。
			subDataSet = splitDataSet(dataSet, i, value)      #求出每个数对应的期望值将其存入到subDataSet中
			prob = len(subDataSet)/float(len(dataSet))      #求出这个标签所对应的属性与所有属性的比例
			newEntropy += prob * calcShannonEnt(subDataSet)
		infoGain = baseEntropy - newEntropy      #这个代码和前两个代码统称为信息增益,信息增益就是有一个固定熵减去一个条件熵下面是求得的最大熵。
		if (infoGain > bestInfoGain):
			bestInfoGain = infoGain      #将最大的熵赋值bestInfoGain
			bestFeature = i      #并将序号保存
	return bestFeature      #返回序号
求最佳方案
def majorityCnt(classList):
	classCount = {}      #创建一个空字典
	for vote in classList:      #访问classList中的每个元素
		if vote not in classCount.keys(): classCount[vote] = 0      #判断classCount中是否含有vote所对应的值,如果没有则将value所对应的值用作键,并将其值赋值为0;
		classCount[vote] += 1      #把所对应的键值加1
	sortedClassCount = sorted(classCount, key=operator.itemgetter(1), reverse=True)      #排序前面我的第一张有讲,这里不做陈述。
	return sortedClassCount[0][0]      #返回占有率最大的那个
	这个函数主要是分类找到距离最近的
def createTree(dataSet, labels):
	classList = [example[-1] for example in dataSet]      #前面有讲
	if classList.count(classList[0]) == len(classList):      #这个的目的就是为了判断classList是否只有一个类别,如果是则停止划分。
		return classList[0]
	if len(dataSet[0]) == 1:      #遍历完所有类型仍没有结束
		return majorityCnt(classList)
	bestFeat = chooseBestFeatureToSplit(dataSet)      #求出最佳方案输出其列序号
	bestFeatLabel = labels[bestFeat]      #特征对应的标签
	myTree = {bestFeatLabel:{}}      #建立树节点即维度标签,并赋予空值
	del(labels[bestFeat])      #从标签列表中删除已选维度标签
	featValues = [example[bestFeat] for example in dataSet]
	uniqueVals = set(featValues)      #同上就是筛选掉重复的值
	for value in uniqueVals:      #将最优维度对应每个值,一个值对应一个分支。
		subLabels = labels[:]      #将labels的值赋值给subLabels
		myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)      #递归建树
	return myTree      #返回树
以下函数每次修改都需要转换代码,我的是这样的。不会转换的可以去网上搜,也可以私聊。
#! D:/software/anaconda3/python.exe
#-*- coding:UTF-8 -*-      #这两个是为了将中文字符串读出来,还有就是要把这段代码转换为utf-8,要不然会出现SyntaxError: (unicode error) 'utf-8' codec can't decode byte 0xbe in position 0: invalid start byte这种错误

import matplotlib.pyplot as plt      #加载绘图函数

decisionNode = dict(boxstyle = "sawtooth", fc="0.8")      #设置小框框的的样式,小框框中的背景色深度
leafNode = dict(boxstyle="round4", fc="0.8")      #与上面中的一样
arrow_args = dict(arrowstyle="<-")      #设置箭头样式
这个主要是为了设置框的形式,框内背景色,箭头的样式注意箭头的方向

def plotNode(nodeTxt, centerPt, parentPt, nodeType):      
	createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', 
							xytext=centerPt, textcoords='axes fraction', 
							va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
#这个函数是为了执行实际的绘图功能。 
def createPlot():
	fig=plt.figure(1, facecolor='white')      #这个是设置背景色,是整个图的背景色。
	fig.clf()      #清空画布
	createPlot.ax1 = plt.subplot(111, frameon=False)      #用来定义全局变量绘图区
	plotNode("决策节点" , (0.5, 0.1), (0.1, 0.5), decisionNode)      调用函数绘图
	plotNode('叶节点' , (0.8, 0.1), (0.3, 0.8), leafNode)      #调用函数绘图
	plt.show()      #展示图像
	#这个函数主要是为了绘制图像
def getNumLeafs(myTree):
	numLeafs = 0      #赋初始值求叶节点
	firstStr = list(myTree.keys())[0]      #求出第一颗树的根节点
	secondDict = myTree[firstStr]      #获得第一个子节点
	for key in secondDict.keys():       #循环遍历每个子结点的根节点
		if type (secondDict[key]).__name__=='dict':      #如果这个节点不是叶子节点则递归继续寻找
			numLeafs += getNumLeafs(secondDict[key])
		else:      #否则对应值加1
			numLeafs += 1
	return numLeafs 
	#这个函数是为了求叶子节点的个数
def getTreeDepth(myTree):
	maxDepth = 0      #赋初始值求深度
	firstStr = list(myTree.keys())[0]      #与上一个函数一样
	secondDict = myTree[firstStr]      #与上一个函数一样
	for key in secondDict.keys():      #同上
		if type(secondDict[key]).__name__=='dict':      如果这个节点不是叶子节点,则1加上其递归值
			thisDepth = 1 + getTreeDepth(secondDict[key])
		else:      #否则其值为1
			thisDepth = 1
		if thisDepth > maxDepth:      #取最大的深度
			maxDepth = thisDepth      #如果是最大的深度则将次深度赋值为最大值
	return maxDepth      #返回最大值
	这个函数是为了求最大深度
def retrieveTree(i):
	listOfTree = [{'no surfacing': {0:'no', 1: {'flippers': {0:'no', 1:'yes'}}}},
					{'no surfacing': {0: 'no', 1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
					]
	return listOfTree[i]
	这个函数只要是为了创建验证数据
def plotMidText(cntrPt, parentPt, txtString):
	xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
	yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
	createPlot.ax1.text(xMid, yMid, txtString)
	#用父子节点填充文本信息
def plotTree(myTree, parentPt, nodeTxt):
	numLeafs = getNumLeafs(myTree)      #获取树的叶子个数,计算树的宽
	depth = getTreeDepth(myTree)      #获取树的深度,计算树的高
	firstStr = list(myTree.keys())[0]      #获取第一个键
	cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)      #计算当前的节点数
	plotMidText(cntrPt, parentPt, nodeTxt)      #调用plotMidText是为了绘出节点当前具有的特征值。计算父子节点的中间位置,并在此填充文本
	plotNode(firstStr, cntrPt, parentPt, decisionNode)      #调用plotNode是为了
	secondDict = myTree[firstStr]      获取键内的树
	plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD      #因从上往下画因此需要递减y的值
	for key in secondDict.keys():
		if type(secondDict[key]).__name__=='dict':
			plotTree(secondDict[key], cntrPt, str(key))
		else:
			plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW      #增加全局变量X的偏移量
			plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)    #上面有解释
			plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))    #上面有解释
	plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD      #增加全局变量Y的偏移量
	
def createPlot(inTree):
	fig = plt.figure(1, facecolor = 'white')      #前面有讲
	fig.clf()      #前面有讲
	axprops = dict(xticks=[], yticks=[])    #设置横纵坐标
	createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)      #由全局变量createPlot.ax1绘图111表示一行一列第一个,frameon表示边框,**axprops表示是否显示刻度 
	plotTree.totalW = float(getNumLeafs(inTree))      #获取叶子的个数
	plotTree.totalD = float(getTreeDepth(inTree))      #获取最大深度,这两个全局变量主要为了使树能有一个好的摆放
	plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;      #求出初始的全局变量xOff 和yOff,追踪已经绘制的点,以及下一个要绘制的恰当位置
	plotTree(inTree, (0.5, 1.0), '')      调用函数绘制图像
	plt.show()      #显示图像
def classify(inputTree, featLabels, testVec):
	firstStr = list(inputTree.keys())[0]      #获取该字典的第一个键
	secondDict = inputTree[firstStr]      #获取相应键所对应的值
	featIndex = featLabels.index(firstStr)      #index列表查找第一个匹配firstStr的索引。所谓索引就是你能通过位置
	for key in secondDict.keys():      #遍历第一个键的值中的键
		if testVec[featIndex] == key:      这个是为了遍历树中的每个子根,直到遍历你想查找的是对应的值,也就是testVec[featIndex]对应的值等于key。
		
			if type(secondDict[key]).__name__=='dict':      #判断你所找到的键是不是字典,如果是则继续递归,不是则输出结果。
				classLabel = classify(secondDict[key], featLabels, testVec)      #递归用的
			else:
				classLabel = secondDict[key]      输出相应的值
	return classLabel      #返回相应的值
	这个函数的目的是分类,寻找叶子节点,遍历叶子节点,返回最终的寻找值
	这个函数目的分类标签求出相应的叶节点
pickle模块可以序列化列表,字典,集合,类等,将这些以文件的形式存入磁盘,但是人们不易阅读。
def storeTree(inputTree, filename):
	import pickle
	fw = open(filename, 'wb+')      #创建一个新文件这里的wb+是因为我要输入字符串,如果是数字则改成w即可。
	pickle.dump(inputTree, fw)      #写入新创建文件, 序列化对象
	fw.close()      #保存文件
	这个函数主要是为了创建文件,序列化对象,存入数据
def grabTree(filename):
	import pickle
	fr = open(filename, 'rb')      #打开文件注意这里的rb,如果我输入的不是字符串则不能这么写
	return pickle.load(fr)      #返回加载后的内容,也就是反序列化
	这个函数的目的读取文件,反序列化,返回读取后的结果。
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值