Implementing ID3 in Python

I won't rehash decision tree theory here; you can look it up online or read my earlier post: https://blog.csdn.net/weixin_38273255/article/details/88752468
Here I mainly list the code I used, along with some notes from when I was learning.

Notes

There is plenty of code for this all over the web, but none of it quite matched what I wanted, so I took code from the web and made some modifications. There are surely rough spots; please bear with me.
What I wanted was to handle continuous data, but ID3 works on discrete data, and so does the sample code online. While studying the theory I read that continuous data can be discretized:

  • Sort the continuous values
  • Split the sorted sequence into equal-frequency intervals (so each interval holds roughly the same number of samples)
  • Assign each value its interval's representative value
  • Discretization done

This approach may have problems of its own, but it is what I implemented; a minimal numpy sketch follows.
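
For reference, here is a rough numpy sketch of the same equal-frequency idea (my own illustration using np.quantile; it is not part of the original code below):

import numpy as np

def equal_freq_bins(col, num_bins):
    # Bin edges at equally spaced quantiles give roughly equal counts per bin
    edges = np.quantile(col, np.linspace(0, 1, num_bins + 1))
    # Replace each value with the left edge of the bin it falls into
    idx = np.clip(np.searchsorted(edges, col, side='right') - 1, 0, num_bins - 1)
    return edges[idx]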

Code
  1. Load the data
  2. Discretize the data
  3. Build the decision tree
  4. Plot the decision tree
  5. Evaluate on a held-out test set

The code below processes the iris dataset; other datasets will need their own adjustments. The helper modules keep their pinyin filenames: lisanhua.py (discretization), creat.py (tree construction), draw.py (plotting), and yanzheng.py (validation).
main.py

# -*- coding: utf-8 -*-
import creat as ct
import draw as dr
import yanzheng as yz
from sklearn import datasets
import numpy as np
import lisanhua as lsh
import random

# Load the iris dataset
iris = datasets.load_iris()
all_data = iris.data[:, :]
all_target = iris.target[:]
labels = iris.feature_names[:]

# Constants
n = 150  # total number of samples
m = int(n * 2 / 3)  # number of samples used to build the tree
q = 4  # number of features
l = 7  # number of discretization bins

# Discretize the data
all_data, a = lsh.lsh(all_data, l)

# Append the class label to each row
all_data = all_data.tolist()
all_target = all_target.tolist()
for i in range(len(all_target)):
    all_data[i].append(all_target[i])

# Shuffle the data
random.shuffle(all_data)

# Data used to build the tree
cj_data = all_data[:m]

# Build the decision tree
myTree = ct.createTree(cj_data, labels)

# Held-out validation data
all_data = np.array(all_data)  # convert back to numpy
yz_target = all_data[m:n, q]   # class labels of the validation rows
yz_data = all_data[m:n, :q]
yz_labels = iris.feature_names[:]

# Measure accuracy on the validation set
yz_shu = yz.yanzheng(myTree, yz_data, yz_labels, yz_target)
yz_bfb = float(yz_shu) / (n - m)

# Report the results
print(myTree)
print(yz_shu)
print(yz_bfb)
dr.createPlot(myTree)
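
Note that random.shuffle draws a different train/validation split on every run, so the printed accuracy will fluctuate. If you want reproducible numbers while experimenting, you can pin the seed near the top of main.py (my own suggestion, not in the original code):

import random
random.seed(42)  # any fixed value makes the shuffle, and thus the split, repeatable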

lisanhua.py

import numpy as np

def lsh(data, num):
	# Equal-frequency discretization: map each column onto `num`
	# representative values so every bin holds roughly the same count.
	a = []
	for i in range(len(data[0])):
		b = []
		# Sort a copy; sorting the view data[:, i] in place would
		# scramble the rows of the original array
		data1 = np.sort(data[:, i])
		l = len(data1)
		# Collect the left boundary of each equal-frequency bin
		for k in range(num):
			b.append(data1[int(k * l / num)])
		# Replace every value with the left boundary of its bin
		for j in range(len(data)):
			if data[j, i] >= b[-1]:
				data[j, i] = b[-1]
				continue
			for q in range(1, num):
				if b[q - 1] <= data[j, i] < b[q]:
					data[j, i] = b[q - 1]
		a.append(b)
	return data, a
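
A quick sanity check of lsh on a single toy column (my own example):

import numpy as np
import lisanhua as lsh

x = np.array([[0.1], [0.4], [0.2], [0.9], [0.7], [0.5]])
binned, boundaries = lsh.lsh(x, 2)
print(binned.ravel())   # [0.1 0.1 0.1 0.5 0.5 0.5]
print(boundaries)       # the left boundary of each of the 2 bins: 0.1 and 0.5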

creat.py

import math
import operator
 
def calcShannonEnt(dataset):
    # Shannon entropy of the class labels (the last element of each row)
    numEntries = len(dataset)
    labelCounts = {}
    for featVec in dataset:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1

    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * math.log(prob, 2)
    return shannonEnt
     
def CreateDataSet():
    dataset = [[1, 1, 'yes' ],
               [1, 1, 'yes' ],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataset, labels
 
def splitDataSet(dataSet, axis, value):
    # Rows whose feature `axis` equals `value`, with that feature column removed
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
 
def chooseBestFeatureToSplit(dataSet):
    # Pick the feature whose split yields the largest information gain
    numberFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numberFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
 
def majorityCnt(classList):
    # Majority vote over the remaining class labels
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
  
 
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # Stop when every remaining sample shares the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop when no features are left; fall back to majority vote
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])  # note: this mutates the caller's labels list
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy so sibling branches see the same labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
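
The helper functions can be exercised on the small toy dataset from CreateDataSet before moving on to iris:

import creat as ct

dataSet, labels = ct.CreateDataSet()
print(ct.calcShannonEnt(dataSet))  # 2 'yes' vs 3 'no' gives about 0.971 bits
print(ct.createTree(dataSet, labels))
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}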

draw.py


# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")
 
 
# Count the leaf nodes of a tree
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=list(myTree.keys())[0]  # list() is needed under Python 3
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs+=getNumLeafs(secondDict[key])
        else: numLeafs+=1
    return numLeafs
 
# Compute the maximum depth of a tree
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=list(myTree.keys())[0]
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth=1+getTreeDepth(secondDict[key])
        else: thisDepth=1
        if thisDepth>maxDepth:
            maxDepth=thisDepth
    return maxDepth
 
# Draw a node
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\
    xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\
    bbox=nodeType,arrowprops=arrow_args)
 
# Draw the text on the connecting arrow
def plotMidText(cntrPt,parentPt,txtString):
    lens=len(txtString)
    xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002
    yMid=(parentPt[1]+cntrPt[1])/2.0
    createPlot.ax1.text(xMid,yMid,txtString)
    
def plotTree(myTree,parentPt,nodeTxt):
    numLeafs=getNumLeafs(myTree)
    depth=getTreeDepth(myTree)
    firstStr=list(myTree.keys())[0]
    cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict=myTree[firstStr]
    plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)
            plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))
    plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalD
 
def createPlot(inTree):
    fig=plt.figure(1,facecolor='white')
    fig.clf()
    axprops=dict(xticks=[],yticks=[])
    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW=float(getNumLeafs(inTree))
    plotTree.totalD=float(getTreeDepth(inTree))
    plotTree.x0ff=-0.5/plotTree.totalW
    plotTree.y0ff=1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()
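
createPlot can be tried on a small hand-built tree without running the whole pipeline (my own example):

import draw as dr

toy = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
dr.createPlot(toy)  # opens a matplotlib window showing the plotted tree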

yanzheng.py

def one(myTree, data, labels):
	# Classify a single sample by walking the tree built by creat.createTree
	if not isinstance(myTree, dict):
		return myTree  # leaf: the class label itself
	featName = list(myTree.keys())[0]
	# Validation rows keep all feature columns, so the feature's position
	# in the full label list is its column index
	featIndex = list(labels).index(featName)
	subtrees = myTree[featName]
	value = data[featIndex]
	if value in subtrees:
		return one(subtrees[value], data, labels)
	# This discretized value never appeared on this path during training;
	# fall back to the closest branch value
	nearest = min(subtrees.keys(), key=lambda v: abs(v - value))
	return one(subtrees[nearest], data, labels)

def getResult(myTree, data, labels):
	result = []
	for elem in data:
		result.append(one(myTree, elem, labels))
	return result

def yanzheng(myTree, data, labels, target):
	# Count how many validation samples the tree classifies correctly
	count = 0
	result = getResult(myTree, data, labels)
	for i in range(len(result)):
		if result[i] == target[i]:
			count += 1
	return count
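
one can also be called directly on a hand-built tree to see the traversal ('f1' and 'f2' below are hypothetical feature names of my own, not iris features):

import yanzheng as yz

labels = ['f1', 'f2']
tree = {'f1': {0.0: 0, 1.0: {'f2': {0.0: 1, 1.0: 2}}}}
print(yz.one(tree, [1.0, 0.0], labels))  # follows f1=1.0 then f2=0.0, prints 1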

Results

For the results, see:
https://blog.csdn.net/weixin_38273255/article/details/88981203
