# Study notes: predicting contact-lens type with a decision tree

# -*- coding: UTF-8 -*-
from math import log
import operator
import matplotlib.pyplot as plt

def calcShannonEnt(dataSet): # Shannon entropy of the class labels
	"""Return the Shannon entropy of the class labels in dataSet.

	Each row's last element is treated as the class label; entropy is
	-sum(p * log2(p)) over the label distribution.
	"""
	total = len(dataSet)
	counts = {}
	for row in dataSet:
		label = row[-1]
		counts[label] = counts.get(label, 0) + 1
	entropy = 0.0
	for count in counts.values():
		p = float(count) / total
		entropy -= p * log(p, 2)
	return entropy

def splitDataSet(dataSet, axis, valus): # filter rows on a feature value
	"""Return the rows whose feature at index ``axis`` equals ``valus``,
	with that feature column removed from each returned row.
	"""
	matched = []
	for row in dataSet:
		if row[axis] != valus:
			continue
		# concatenate the slices around the split column
		matched.append(row[:axis] + row[axis + 1:])
	return matched

def createDataSet(): # tiny hard-coded fixture for testing the tree code
	"""Return a small (dataSet, labels) pair: two binary features plus
	a 'yes'/'no' class label per row.
	"""
	samples = [
		[1, 1, 'yes'],
		[1, 1, 'yes'],
		[1, 0, 'no'],
		[0, 1, 'no'],
		[0, 1, 'no'],
	]
	featureNames = ['no surfacing', 'flippers']
	return samples, featureNames

def chooseBestFeatureToSplit(dataSet): # pick the feature with max information gain
	"""Return the index of the feature whose split yields the highest
	information gain over dataSet, or -1 if no split improves entropy.

	The last column of each row is the class label and is not a feature.
	"""
	numFeatures = len(dataSet[0]) - 1  # exclude the class-label column
	baseEntropy = calcShannonEnt(dataSet)
	bestInfoGain = 0.0
	bestFeature = -1
	for i in range(numFeatures):
		uniqueVals = set(example[i] for example in dataSet)
		newEntropy = 0.0
		for value in uniqueVals:
			subDataSet = splitDataSet(dataSet, i, value)
			prob = len(subDataSet) / float(len(dataSet))
			newEntropy += prob * calcShannonEnt(subDataSet)
		infoGain = baseEntropy - newEntropy
		# parenthesized print works under both Python 2 and Python 3
		# (the original bare print statement is a SyntaxError on Python 3)
		print("%d,%f" % (i, infoGain))
		if infoGain > bestInfoGain:
			bestInfoGain = infoGain
			bestFeature = i
	return bestFeature

def majorityCnt(classList): # majority vote when labels are still mixed
	"""Return the most frequent label in classList (majority vote).

	Used when all features are exhausted but the class labels at a
	node are still not unique.
	"""
	classCount = {}
	for vote in classList:
		classCount[vote] = classCount.get(vote, 0) + 1
	# .items() replaces Python-2-only .iteritems(); sort by count, descending
	sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
	return sortedClassCount[0][0]

def createTree(dataSet,labels):	# recursively build the decision tree
	"""Build a decision tree as nested dicts: {featureName: {value: subtree-or-label}}.

	NOTE: this mutates ``labels`` (the chosen feature name is deleted);
	pass a copy if the caller needs the original list afterwards.
	"""
	classList = [example[-1] for example in dataSet]
	# all samples share one class: return a leaf
	if classList.count(classList[0]) == len(classList): return classList[0]
	# no features left (only the label column): majority-vote leaf
	if len(dataSet[0]) == 1: return majorityCnt(classList)
	bestFeat = chooseBestFeatureToSplit(dataSet)
	# parenthesized print works under both Python 2 and Python 3
	# (the original bare print statement is a SyntaxError on Python 3)
	print("bestFeat=%s" % labels[bestFeat])
	bestFeatLabel = labels[bestFeat]
	myTree = {bestFeatLabel:{}}
	del(labels[bestFeat])
	featValues = [example[bestFeat] for example in dataSet]
	uniqueVals = set(featValues)
	for value in uniqueVals:
		subLabels = labels[:]  # copy so sibling branches see the same labels
		subDataSet = splitDataSet(dataSet, bestFeat, value)
		myTree[bestFeatLabel][value] = createTree(subDataSet, subLabels)
	return myTree

def plotNode(nodeTxt, centerPt, parentPt, nodeType): # draw one annotated node
	"""Draw nodeTxt in a box at centerPt with an arrow from parentPt,
	both in axes-fraction coordinates on createPlot.ax1.
	"""
	createPlot.ax1.annotate(
		nodeTxt,
		xy=parentPt, xycoords='axes fraction',
		xytext=centerPt, textcoords='axes fraction',
		va="center", ha="center",
		bbox=nodeType,
		arrowprops=dict(arrowstyle="<-"))

def getNumLeafs(myTree): # count the leaf nodes of a nested-dict tree
	"""Return the number of leaves in the nested-dict decision tree.

	A child that is itself a dict is an internal node; anything else
	(a class label) counts as one leaf.
	"""
	numLeafs = 0
	# next(iter(...)) replaces Python-2-only myTree.keys()[0]
	# (dict views are not indexable on Python 3)
	firstStr = next(iter(myTree))
	secondDict = myTree[firstStr]
	for key in secondDict:
		if isinstance(secondDict[key], dict):
			numLeafs += getNumLeafs(secondDict[key])
		else:
			numLeafs += 1
	return numLeafs

def getTreeDepth(myTree): # depth of a nested-dict tree
	"""Return the depth of the nested-dict decision tree: the number of
	decision levels on the longest root-to-leaf path.
	"""
	maxDepth = 0
	# next(iter(...)) replaces Python-2-only myTree.keys()[0]
	# (dict views are not indexable on Python 3)
	firstStr = next(iter(myTree))
	secondDict = myTree[firstStr]
	for key in secondDict:
		if isinstance(secondDict[key], dict):
			thisDepth = 1 + getTreeDepth(secondDict[key])
		else:
			thisDepth = 1
		if thisDepth > maxDepth:
			maxDepth = thisDepth
	return maxDepth

def plotMidText(cntPt, parentPt, txtString): # label an edge with its decision value
	"""Write txtString at the midpoint of the edge between parentPt and cntPt."""
	midX = (parentPt[0] + cntPt[0]) / 2.0
	midY = (parentPt[1] + cntPt[1]) / 2.0
	createPlot.ax1.text(midX, midY, txtString)

def plotTree(myTree, parentPt, nodeTxt): # recursively lay out and draw the tree
	"""Draw myTree below parentPt, labelling the incoming edge with nodeTxt.

	Uses function attributes set by createPlot: totalw/totalD (leaf count
	and depth for scaling) and xoff/yoff (the current drawing cursor).
	"""
	decisionNode = dict(boxstyle="sawtooth", fc="0.8")
	leafNode = dict(boxstyle="round4", fc="0.8")

	if not isinstance(myTree, dict):
		# leaf: advance the x cursor one leaf-width and draw the label
		plotTree.xoff += 1.0/plotTree.totalw
		plotMidText((plotTree.xoff, plotTree.yoff), parentPt, nodeTxt)
		plotNode(myTree, (plotTree.xoff, plotTree.yoff), parentPt, leafNode)
	else:
		numLeafs = getNumLeafs(myTree)
		depth = getTreeDepth(myTree)
		# center the decision node over the span of its leaves
		cntPt = (plotTree.xoff + (1.0 + float(numLeafs))/2.0/plotTree.totalw, plotTree.yoff)
		# next(iter(...)) replaces Python-2-only myTree.keys()[0]
		firstStr = next(iter(myTree))
		plotMidText(cntPt, parentPt, nodeTxt)
		plotNode(firstStr, cntPt, parentPt, decisionNode)
		secondDict = myTree[firstStr]
		plotTree.yoff -= 1.0/plotTree.totalD  # descend one level
		for key in secondDict:
			plotTree(secondDict[key], cntPt, str(key))
		plotTree.yoff += 1.0/plotTree.totalD  # restore level on the way back up

def createPlot(inTree): # set up the canvas and render the tree
	"""Create a frameless, tickless figure and draw the decision tree inTree."""
	fig = plt.figure(1, facecolor='white')
	fig.clf()
	createPlot.ax1 = plt.subplot(111, frameon=False, **dict(xticks=[], yticks=[]))
	# scaling factors and the initial drawing cursor for plotTree
	plotTree.totalw = float(getNumLeafs(inTree))
	plotTree.totalD = float(getTreeDepth(inTree))
	plotTree.xoff = -0.5/plotTree.totalw
	plotTree.yoff = 1.0
	plotTree(inTree, (0.5, 1.0), '')
	plt.show()

def retrieveTree(i): # canned trees used as test fixtures
	"""Return the i-th pre-built decision tree (nested-dict form)."""
	trees = [
		{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
		{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
	]
	return trees[i]

def classify(inputTree, featLabels, testVec): # predict with a built tree
	"""Walk the nested-dict tree and return the predicted class label.

	featLabels gives the position in testVec of each feature name.
	NOTE: if testVec's value for the root feature matches no branch,
	classLabel is never bound and an UnboundLocalError is raised
	(behavior inherited from the original implementation).
	"""
	# next(iter(...)) replaces Python-2-only inputTree.keys()[0]
	firstStr = next(iter(inputTree))
	secondDict = inputTree[firstStr]
	featIndex = featLabels.index(firstStr)
	for key in secondDict:
		if testVec[featIndex] == key:
			if isinstance(secondDict[key], dict):
				classLabel = classify(secondDict[key], featLabels, testVec)
			else:
				classLabel = secondDict[key]
	return classLabel

def storeTree(inputTree, filename): # pickle the tree to disk
	"""Serialize inputTree to filename with pickle.

	Opens in binary mode ('wb'): pickle data is bytes, and text mode
	fails on Python 3 (and corrupts data on Windows under Python 2).
	The `with` block guarantees the file is closed even on error.
	"""
	import pickle
	with open(filename, 'wb') as fw:
		pickle.dump(inputTree, fw)

def grabTree(filename): # load a pickled tree from disk
	"""Deserialize and return a decision tree pickled at filename.

	Opens in binary mode ('rb'): pickle data is bytes, and text mode
	fails on Python 3. The `with` block closes the file (the original
	leaked the handle).
	"""
	import pickle
	with open(filename, 'rb') as fr:
		return pickle.load(fr)

def file2matrix(filename): # load the tab-separated lenses dataset
	"""Read a tab-separated data file and return (rows, labels).

	Each row is a list of string fields (last field is the class label);
	labels names the four lenses feature columns. The `with` block closes
	the file (the original leaked the handle).
	"""
	with open(filename) as fr:
		returnMat = [line.strip().split('\t') for line in fr]
	labels = ['age', 'prescript', 'astigmatic', 'tearRate']
	return returnMat, labels

if __name__ == '__main__':
	# Load the lenses dataset, build a decision tree from it, and draw it.
	lensesPath = '/Users/ZZ/Desktop/MY_FILE/MACHINE_LEARNING_IN_ACTION/machinelearninginaction/Ch03/lenses.txt'
	lensesData, lensesLabels = file2matrix(lensesPath)
	createPlot(createTree(lensesData, lensesLabels))

# (Removed: CSDN blog-platform page residue — like/favorite counters and
# payment-widget text captured when the article was scraped. It was not
# part of the program and made the file unparseable as Python.)