Python 决策树:用 math 库实现 C4.5 算法

每周一搏,提升自我。

这段时间一直在使用 Python,对它的理解越来越深。我在摸索中修改了网上的示例代码,形成了自己的理解。

C4.5 是 ID3 算法的改进版:它用信息增益率(信息增益除以特征自身取值的熵)代替 ID3 使用的信息增益来选择划分特征,从而减轻了 ID3 偏向取值较多特征的问题。个人建议使用 CART 算法,感觉比 C4.5 更好。

下面是 C4.5 的实现代码。用于绘制决策树结构的代码(showTree.py)附在文末,下篇博文会另行介绍。

#!/usr/bin/python
#coding:utf-8

import operator
from math import log
import time
import os,sys
import string

#Load the training set from a file. Note: the zero-argument createDataSet
#defined later in this file redefines (shadows) this function.
def createDataSet(trainDataFile):
	"""Read a comma-separated training file into a data set.

	Each non-blank line must have at least 11 comma-separated columns.
	Column 0 is the class label and columns 1-10 are the features, in the
	order of the returned ``labels`` list; extra columns are ignored. Rows
	are returned with the features first and the class label last, which is
	the layout createTree expects.

	Args:
		trainDataFile: path of the training data file.

	Returns:
		(dataSet, labels): list of rows (lists of strings) and the list of
		the 10 feature names.

	Exits via SystemExit(1) with a usage message if the file cannot be
	opened, or with an error message if a line has too few columns.
	"""
	print(trainDataFile)
	dataSet=[]
	try:
		fin=open(trainDataFile)
	except (IOError, OSError):
		print('Usage xxx.py trainDataFilePath')
		sys.exit(1)
	with fin:
		for lineNo, line in enumerate(fin, 1):
			# rstrip also drops the '\r' left over from Windows line endings
			line=line.rstrip('\r\n')
			if not line.strip():
				continue  # tolerate blank lines such as a trailing empty line
			cols=line.split(',')
			if len(cols) < 11:
				print('Malformed line %d in %s: expected at least 11 columns, got %d' % (lineNo, trainDataFile, len(cols)))
				sys.exit(1)
			# features (columns 1-10) first, class label (column 0) last
			dataSet.append(cols[1:11] + [cols[0]])
	labels=['cip1', 'cip2', 'cip3', 'cip4', 'sip1', 'sip2', 'sip3', 'sip4', 'sport', 'domain']
	print('dataSetlen %d' % len(dataSet))
	return dataSet,labels

#C4.5 helper: Shannon entropy of one column of the data set
def calcShannonEntOfFeature(dataSet,feat):
	"""Return the base-2 Shannon entropy of column ``feat`` over ``dataSet``.

	feat=-1 gives the entropy of the class labels (last column); a feature
	index gives that feature's split information, the denominator of the
	C4.5 gain ratio. An empty data set has entropy 0.0.
	"""
	n=len(dataSet)
	freq={}
	for row in dataSet:
		freq[row[feat]]=freq.get(row[feat],0)+1
	entropy=0.0
	for count in freq.values():
		p=float(count)/n
		entropy-=p*log(p,2)
	return entropy

def splitDataSet(dataSet,axis,value):
	"""Return the rows whose column ``axis`` equals ``value``, with that
	column removed.

	New row lists are built, so ``dataSet`` itself is left unchanged.
	"""
	return [row[:axis]+row[axis+1:] for row in dataSet if row[axis]==value]

def chooseBestFeatureToSplit(dataSet):
	"""Return the index of the feature with the highest C4.5 gain ratio.

	For each feature column:
		gain      = H(labels) - sum over values v of |D_v|/|D| * H(labels of D_v)
		splitInfo = entropy of the feature's own values
		ratio     = gain / splitInfo
	Columns whose splitInfo is 0 (a single distinct value) are skipped. The
	last column of each row is the class label and is never chosen.
	Returns -1 if no feature reaches a strictly positive gain ratio.
	"""
	total=float(len(dataSet))
	baseEntropy=calcShannonEntOfFeature(dataSet,-1)
	bestFeature,bestInfoGainRate=-1,0.0
	for col in range(len(dataSet[0])-1):
		splitInfo=calcShannonEntOfFeature(dataSet,col)
		if splitInfo==0:
			continue  # constant column: the ratio would divide by zero
		condEntropy=0.0
		for value in set([row[col] for row in dataSet]):
			subset=splitDataSet(dataSet,col,value)
			condEntropy+=len(subset)/total*calcShannonEntOfFeature(subset,-1)
		ratio=(baseEntropy-condEntropy)/splitInfo
		if ratio>bestInfoGainRate:
			bestFeature,bestInfoGainRate=col,ratio
	return bestFeature

def majorityCnt(classList):
	"""Return the most frequent class label in ``classList``.

	Ties go to the tied label that appears first in ``classList``.
	Raises ValueError if ``classList`` is empty.
	"""
	classCount={}
	for vote in classList:
		classCount[vote]=classCount.get(vote,0)+1
	# Bug fix: max(classCount) compared the labels themselves (returning the
	# largest label, e.g. 'run' > 'fight'), not their counts. Scanning
	# classList keeps tie-breaking deterministic (first occurrence wins).
	return max(classList, key=classCount.get)


def createTree(dataSet,labels):
	"""Recursively build a C4.5 decision tree.

	Args:
		dataSet: list of rows. Every column except the last is a feature and
			the last column is the class label.
		labels: feature names, one per feature column of ``dataSet``. The
			caller's list is NOT modified.

	Returns:
		Either a class label (a leaf) or a nested dict of the form
		{featureName: {featureValue: subtree, ...}}.
	"""
	classList=[example[-1] for example in dataSet]
	# every row has the same class: pure leaf
	if classList.count(classList[0])==len(classList):
		return classList[0]
	# no feature columns left: majority vote
	if len(dataSet[0])==1:
		return majorityCnt(classList)
	bestFeat=chooseBestFeatureToSplit(dataSet)
	# no feature has a positive gain ratio, so splitting cannot help: use the
	# majority class rather than whichever class the first row happens to have
	if bestFeat==-1:
		return majorityCnt(classList)
	bestFeatLabel=labels[bestFeat]
	myTree={bestFeatLabel:{}}
	# Build the reduced label list as a new list. The original
	# del(labels[bestFeat]) removed an entry from the caller's list, so the
	# labels could no longer be used with classify() afterwards.
	subLabels=labels[:bestFeat]+labels[bestFeat+1:]
	for value in set([example[bestFeat] for example in dataSet]):
		myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
	return myTree


#Create a small toy data set. Features: weapon (0 rifle, 1 machine gun),
#bullet (0 few, 1 many), blood (0 low, 1 high); class: 'fight' or 'run'.
#Note: this redefines createDataSet, replacing the file-based version above.
def createDataSet():
	"""Return (dataSet, labels) for the toy fight/run example."""
	rows=[
		[1,1,0,'fight'],
		[1,0,1,'fight'],
		[1,0,1,'fight'],
		[1,0,1,'fight'],
		[0,0,1,'run'],
		[0,1,0,'fight'],
		[0,1,1,'run'],
	]
	featureNames=['weapon','bullet','blood']
	return rows,featureNames

#Print the data set one row per line
def printData(myData):
	"""Print each item of ``myData`` on its own line, using its str() form."""
	for item in myData:
		# (item,) is a one-element tuple. With '%s' % (item), a tuple row would
		# be treated as the format-argument tuple and raise TypeError.
		print('%s' % (item,))

#Classify a sample with the decision tree
def classify(inputTree,featLabels,testVec):
	"""Return the class label that ``inputTree`` predicts for ``testVec``.

	Args:
		inputTree: a tree from createTree, i.e. a nested dict
			{featureName: {featureValue: subtree, ...}} or a bare class label.
		featLabels: feature names, in the same order as the values in testVec.
		testVec: the feature values of the sample to classify.

	Raises:
		ValueError: if a node's feature is missing from featLabels, or if the
			sample's value for that feature has no branch in the tree. (The
			original code failed with UnboundLocalError in the latter case.)
	"""
	node=inputTree
	# walk down until we reach a non-dict node, i.e. a leaf holding the label
	while isinstance(node,dict):
		featName=next(iter(node))  # works on Python 2 and 3, unlike keys()[0]
		branches=node[featName]
		value=testVec[featLabels.index(featName)]
		if value not in branches:
			raise ValueError('no branch for %s=%r in the decision tree' % (featName, value))
		node=branches[value]
	return node

#Save the decision tree to a file with pickle
def storeTree(inputTree,filename):
	"""Serialize ``inputTree`` to ``filename`` with pickle.

	The file is opened in binary mode, as pickle requires (text mode 'w'
	fails on Python 3). The with-block closes the file even if pickling
	raises.
	"""
	import pickle
	with open(filename,'wb') as fw:
		pickle.dump(inputTree,fw)


#Load a decision tree saved with pickle
def grabTree(filename):
	"""Load and return the pickled decision tree stored in ``filename``.

	The file is opened in binary mode (required by pickle on Python 3) and
	closed after loading.

	Security: pickle.load can execute arbitrary code while unpickling. Only
	load files you created yourself (e.g. with storeTree), never untrusted
	input.
	"""
	import pickle
	with open(filename,'rb') as fr:
		return pickle.load(fr)



def main():
	"""Build a C4.5 tree from the toy data set, print it, then plot it."""
	dataSet,featureNames=createDataSet()
	tree=createTree(dataSet,featureNames)
	print(tree)

	# plot the tree with the local showTree module
	import showTree
	showTree.createPlot(tree)


if __name__ == '__main__':
	main()

上面代码中调用的 showTree.py 内容如下:

#!/usr/bin/python
#coding:utf-8

import matplotlib.pyplot as plt

#Drawing styles for the decision-tree plot
decisionNode={'boxstyle': 'sawtooth', 'fc': '0.8'}  # box for decision (feature) nodes
leafNode={'boxstyle': 'round4', 'fc': '0.8'}        # box for leaf (class) nodes
arrow_args={'arrowstyle': '<-'}                     # arrow linking a node to its parent


#Entry point: draw a nested-dict decision tree in a matplotlib window using
#the helper functions defined in this module.
def createPlot(inThree):
	"""Draw the decision tree ``inThree`` (a nested dict) and show the figure.

	Layout state is stored as function attributes. createPlot.ax1 holds the
	axes. plotTree.totalW / totalD (leaf count and depth) scale the layout,
	and plotTree.xOff / yOff track the current drawing position.
	"""
	fig=plt.figure(1,facecolor='white')
	fig.clf()
	# no axis ticks; drop xticks/yticks to show them while debugging
	createPlot.ax1=plt.subplot(111,frameon=False,xticks=[],yticks=[])
	width=float(getNumLeafs(inThree))
	height=float(getTreeDepth(inThree))
	plotTree.totalW=width
	plotTree.totalD=height
	# start half a leaf slot left of 0, so the first leaf lands at 0.5/totalW
	plotTree.xOff=-0.5/width
	plotTree.yOff=1.0
	plotTree(inThree,(0.5,1.0),'')
	plt.show()

#Draw one node box together with the arrow that links it to its parent
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
	"""Annotate createPlot.ax1 with ``nodeTxt`` in a box styled by
	``nodeType``, centred at ``centerPt`` and joined to ``parentPt`` by an
	arrow styled by arrow_args. Coordinates are axes fractions."""
	ax=createPlot.ax1
	ax.annotate(nodeTxt,
		xy=parentPt,xycoords='axes fraction',
		xytext=centerPt,textcoords='axes fraction',
		va="center",ha="center",
		bbox=nodeType,arrowprops=arrow_args)

#Write a branch label halfway between a child node and its parent
def plotMidText(cntrPt,parentPt,txtString):
	"""Place ``txtString``, rotated 30 degrees, at the midpoint of the segment
	between ``cntrPt`` and ``parentPt`` on createPlot.ax1."""
	xMid,yMid=[(p-c)/2.0+c for p,c in zip(parentPt,cntrPt)]
	createPlot.ax1.text(xMid,yMid,txtString,va="center",ha="center",rotation=30)

#Count the leaves of the decision tree
def getNumLeafs(myTree):
	"""Return the number of leaves in ``myTree``.

	``myTree`` is a nested dict {featureName: {value: subtree, ...}}, where
	any non-dict subtree is a leaf. A bare leaf (non-dict) counts as 1.
	"""
	if not isinstance(myTree,dict):
		return 1
	# next(iter(...)) works on Python 2 and 3; myTree.keys()[0] fails on
	# Python 3, where keys() returns a non-indexable view
	firstStr=next(iter(myTree))
	return sum(getNumLeafs(subtree) for subtree in myTree[firstStr].values())

#Depth of the decision tree (number of decision levels)
def getTreeDepth(myTree):
	"""Return the depth of ``myTree``: the number of decision nodes on its
	longest root-to-leaf path. A bare leaf (non-dict) has depth 0."""
	if not isinstance(myTree,dict):
		return 0
	# next(iter(...)) instead of keys()[0], which fails on Python 3
	firstStr=next(iter(myTree))
	# "or [0]" gives a node with no branches depth 0 (max() of an empty list
	# raises, and max(default=...) exists only on Python 3)
	return max([1+getTreeDepth(subtree) for subtree in myTree[firstStr].values()] or [0])

#Recursively draw a (sub)tree below its parent node
def plotTree(myTree,parentPt,nodeTxt):
	"""Draw ``myTree`` below ``parentPt`` and label the incoming branch with
	``nodeTxt``.

	Relies on the layout state initialised by createPlot. plotTree.totalW /
	totalD hold the total leaf count and depth; plotTree.xOff / yOff hold the
	current position in axes fractions, which this function advances as it
	draws.
	"""
	numLeafs=getNumLeafs(myTree)
	# next(iter(...)) instead of keys()[0], which fails on Python 3
	firstStr=next(iter(myTree))
	# centre this decision node horizontally over the leaf slots it spans
	cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)
	plotMidText(cntrPt,parentPt,nodeTxt)
	plotNode(firstStr,cntrPt,parentPt,decisionNode)
	secondDict=myTree[firstStr]
	# children are drawn one level lower
	plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD
	for key in secondDict.keys():
		if isinstance(secondDict[key],dict):
			plotTree(secondDict[key],cntrPt,str(key))
		else:
			# leaf: advance to the next leaf slot and draw it there
			plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW
			plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
			plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
	# move back up a level before returning to the parent's remaining branches
	plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD

 

转载于:https://my.oschina.net/wangzonghui/blog/1617580

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值