构建决策树(ID3算法)

#决策树(ID3算法)
from math import log
import operator
import json
import treePlotter
# Create the dataset
def create_dataset(path="E://data.txt"):
    """Load training samples from a whitespace-separated integer file.

    Each non-empty line is one sample: integer feature values followed by the
    integer class label in the last column.

    Args:
        path: Location of the data file. Defaults to the original hard-coded
            ``"E://data.txt"`` so existing callers keep working.

    Returns:
        A ``(dataset, labels)`` tuple. ``dataset`` is a list of rows (lists of
        ints); ``labels`` holds the feature names that ``creattree`` indexes
        by feature column (the class column has no name).
    """
    dataset = []
    # ``with`` guarantees the file handle is closed (the original leaked it).
    with open(path) as data:
        for line in data:
            fields = line.split()
            if not fields:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # appending an empty row that would break ``row[-1]`` later.
                continue
            dataset.append([int(field) for field in fields])
    labels = ['Age','Number','Time','Pressure','Problem']
    return dataset, labels

# Compute the information (Shannon) entropy of the class labels
def calc_shangEnt(dataset):
    """Return the base-2 Shannon entropy of the class column of *dataset*.

    The class label of each row is its last element. An empty dataset
    yields ``0``.
    """
    total = len(dataset)
    # Count rows per class label, in first-seen order.
    tally = {}
    for row in dataset:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    # H = -sum(p * log2(p)) over the observed labels.
    entropy = 0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy

# Split the dataset: axis is the column index, value is the wanted value in that column
def splitdataset(dataset,axis,value):
    """Return the rows of *dataset* whose column *axis* equals *value*.

    The matching column is removed from each returned row, so the result has
    one column fewer than the input. Returned rows are new lists; the input
    rows are not modified.

    Args:
        dataset: Sequence of rows (lists).
        axis: Index of the column to test and drop.
        value: Value the column must equal for a row to be kept.

    Returns:
        List of reduced rows; empty if no row matches.
    """
    # row[:axis] + row[axis + 1:] is the row without column ``axis``.
    return [
        featVec[:axis] + featVec[axis + 1:]
        for featVec in dataset
        if featVec[axis] == value
    ]

# Choose the feature with the largest information gain (ID3)
def bestfeature(dataset):
    """Return the index of the feature column with the highest information gain.

    Information gain of feature i is H(D) - H(D | i), where H(D | i) is the
    size-weighted entropy of the subsets obtained by splitting on each value
    of column i. The last column of every row is the class label and is not
    considered a feature.

    Args:
        dataset: Non-empty list of rows (features followed by class label).

    Returns:
        Index of the best feature, or ``-1`` when no feature yields a strictly
        positive gain.

    Raises:
        IndexError: If *dataset* is empty.
    """
    numfeature = len(dataset[0]) - 1  # last column is the class label
    baseEntropy = calc_shangEnt(dataset)
    total = float(len(dataset))
    bestInfoGain = 0.0
    best_index = -1
    for i in range(numfeature):
        uniqueVals = {example[i] for example in dataset}
        # Conditional entropy must restart at zero for every feature; the
        # original accumulated it across features.
        condEntropy = 0.0
        for value in uniqueVals:
            # Split on this single value (the original passed the whole set,
            # which never matched, so every subset was empty).
            subdataset = splitdataset(dataset, i, value)
            prob = len(subdataset) / total
            condEntropy += prob * calc_shangEnt(subdataset)
        infoGain = baseEntropy - condEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            best_index = i
    return best_index

# Majority vote over class labels
def majorityCnt(classList):
    """Return the most frequent label in *classList*.

    Ties are broken in favour of the label seen first. An empty list raises
    ``IndexError``.
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    # Stable descending sort keeps first-seen order among equal counts.
    ranked = sorted(tally.items(), key=operator.itemgetter(1), reverse=True)
    winner, _ = ranked[0]
    return winner

# Recursively build the decision tree
def creattree(dataset,labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataset: Non-empty list of rows; each row holds feature values
            followed by the class label as its last element.
        labels: Feature names, one per feature column of *dataset*. The
            caller's list is not modified.

    Returns:
        A class label (leaf) or a nested dict
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataset]
    # Every row has the same class: pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column remains: nothing left to split on.
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    bestfeat = bestfeature(dataset)
    if bestfeat < 0:
        # No feature has positive information gain. Continuing would read
        # labels[-1] and split on the class column, so emit a majority leaf.
        return majorityCnt(classList)
    bestfeatlabel = labels[bestfeat]
    myTree = {bestfeatlabel: {}}
    # Build the reduced name list without mutating the caller's ``labels``
    # (the original ``del labels[bestfeat]`` changed it in place).
    sublabels = labels[:bestfeat] + labels[bestfeat + 1:]
    for value in {example[bestfeat] for example in dataset}:
        myTree[bestfeatlabel][value] = creattree(
            splitdataset(dataset, bestfeat, value), sublabels
        )
    return myTree

if __name__ == '__main__':
    # Load samples, build the ID3 tree, draw it, then print it as JSON.
    rows, feature_names = create_dataset()
    tree = creattree(rows, feature_names)
    treePlotter.createPlot(tree)
    print(json.dumps(tree, ensure_ascii=False))


import matplotlib.pyplot as plt

# Box shape, padding and face colour for decision and leaf nodes, plus the arrow style.
# A numeric face-colour string is read by Matplotlib as a grayscale level and
# must lie in [0, 1]; the original "2" is rejected by recent Matplotlib versions.
decisionNode = dict(boxstyle="square,pad=0.05", fc="0.8")
leafNode = dict(boxstyle="round4,pad=0.05", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw boxed text *nodeTxt* at *centerPt*, connected by an arrow to *parentPt*.

    Both points are in axes-fraction coordinates. *nodeType* is the bbox
    style dict (``decisionNode`` or ``leafNode``). Draws on ``createPlot.ax1``.
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="top",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )

def createPlot(inTree):
    """Draw the nested-dict decision tree *inTree* and show the figure.

    Sets up the layout state read by ``plotTree`` (stored as attributes on
    the ``plotTree`` function) and the axes stored on ``createPlot.ax1``.
    """
    fig = plt.figure(1, facecolor='white')  # figure number 1
    fig.clf()  # wipe anything previously drawn on it
    # Frameless axes with no ticks: only nodes and arrows are shown.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    width = float(getNumLeafs(inTree))
    plotTree.totalW = width
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of 0 so leaves are centred in their slots.
    plotTree.xOff = -0.5 / width
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

def getNumLeafs(myTree):
    """Count the leaves of a nested-dict decision tree.

    *myTree* has the form ``{feature: {branch_value: child}}``; a child that
    is a plain ``dict`` is a subtree, anything else counts as one leaf.
    """
    root_label = list(myTree)[0]
    branches = myTree[root_label]
    return sum(
        getNumLeafs(child) if type(child).__name__ == 'dict' else 1
        for child in branches.values()
    )

def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    A branch ending in a leaf contributes depth 1; a branch ending in a
    subtree contributes 1 plus that subtree's depth. No branches gives 0.
    """
    root_label = list(myTree)[0]
    depths = [
        1 + getTreeDepth(child) if type(child).__name__ == 'dict' else 1
        for child in myTree[root_label].values()
    ]
    return max(depths, default=0)

def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* halfway between *cntrPt* and *parentPt* on ``createPlot.ax1``."""
    cx, cy = cntrPt[0], cntrPt[1]
    px, py = parentPt[0], parentPt[1]
    mid_x = cx + (px - cx) / 2.0
    mid_y = cy + (py - cy) / 2.0
    createPlot.ax1.text(mid_x, mid_y, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw one subtree of the decision tree.

    Args:
        myTree: Nested-dict subtree ``{feature_name: {branch_value: child}}``.
        parentPt: (x, y) axes-fraction position of the parent node.
        nodeTxt: Text written on the edge from the parent (the branch value);
            ``createPlot`` passes ``''`` for the root.

    Layout state lives in attributes on this function, initialised by
    ``createPlot``: ``totalW`` (total leaf count), ``totalD`` (tree depth),
    ``xOff`` (x of the most recently placed leaf slot) and ``yOff`` (y of the
    current level). ``xOff`` and ``yOff`` are mutated here.
    """
    numLeafs = getNumLeafs(myTree)
    # NOTE(review): ``depth`` is computed but never used below.
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # Centre this node over its own leaves, which will occupy the next
    # ``numLeafs`` slots after xOff (each slot is 1/totalW wide).
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    # Label the incoming edge, then draw this decision node.
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step down one level before drawing the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # Subtree: recurse, using this node as the parent.
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Leaf: advance to the next slot and draw it there.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the level so sibling subtrees start at the same height.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值