# 机器学习实战-第三章 决策树算法 (Machine Learning in Action, Chapter 3: Decision Trees)

from math import log

def calcShannonEnt(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in *dataSet*.

    Args:
        dataSet: sequence of rows; the last element of every row
            (``row[-1]``) is its class label. Other columns are ignored.

    Returns:
        float: ``-sum(p * log2(p))`` over the relative frequency ``p`` of each
        distinct label. An empty data set (or a single class) gives 0.0.
    """
    numEntries = len(dataSet)

    # Count how many rows carry each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1

    # Initialised before the summation loop: previously this was assigned
    # inside the counting loop, so an empty data set raised UnboundLocalError.
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

def createDataSet():
    """Return a small toy data set together with its feature names.

    Each row is ``[no surfacing, flippers, class label]``. New lists are
    built on every call, so callers may mutate the result freely.

    Returns:
        tuple: ``(dataSet, labels)`` where *labels* names the two feature
        columns of *dataSet*.
    """
    rows = [
        [1, 1, "yes"],
        [1, 1, "yes"],
        [1, 0, "no"],
        [0, 1, "no"],
        [0, 1, "no"],
    ]
    featureNames = ["no surfacing", "flippers"]
    return rows, featureNames

def splitDataSet(dataSet, axis, value):
    """Return the rows whose column *axis* equals *value*, minus that column.

    Args:
        dataSet: list of list rows.
        axis: index of the column to test and then drop.
        value: value that column must equal for a row to be kept.

    Returns:
        list: new row lists (the input rows are not modified), in input order.
    """
    matches = []
    for row in dataSet:
        if row[axis] != value:
            continue
        # Slicing copies, so extending the copy leaves the original row intact.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        matches.append(trimmed)
    return matches

def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the highest information gain.

    Args:
        dataSet: list of rows; every row holds the feature values followed by
            the class label as its last element.

    Returns:
        int: index of the best feature, or -1 when no feature gives a strictly
        positive information gain (including an empty *dataSet*, which used
        to raise IndexError).
    """
    if not dataSet:
        return -1

    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)  # entropy before any split
    numRows = float(len(dataSet))

    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Distinct values of feature i (a set removes duplicates).
        uniqueVals = {example[i] for example in dataSet}
        # Weighted entropy of the partitions induced by feature i.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / numRows
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


import operator
def majorityCnt(classList):
    """Return the most frequent label in *classList*.

    Ties go to the label that first appears earliest in *classList*, because
    the sort below is stable. An empty list raises IndexError.
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    ranked = sorted(tally.items(), key=lambda pair: pair[1], reverse=True)
    return ranked[0][0]

def createTree(dataSet, labels):
    """Recursively build an ID3-style decision tree as nested dicts.

    Args:
        dataSet: list of rows; every row holds the feature values followed by
            the class label as its last element.
        labels: feature names, one per feature column of *dataSet*. The list
            is NOT modified (it previously had the chosen feature deleted).

    Returns:
        Either a class label (a leaf) or a dict
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [example[-1] for example in dataSet]

    # Every row has the same class: this node is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # No feature columns left to split on: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)
    # -1 means no feature gives a positive information gain. Using it as an
    # index would split on the class column itself and pick labels[-1], so
    # make a majority-vote leaf instead.
    if bestFeat < 0:
        return majorityCnt(classList)

    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Remaining feature names for the subtrees, built without mutating the
    # caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in {example[bestFeat] for example in dataSet}:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

import matplotlib.pyplot as plt

# Text-box and arrow styles for drawing the tree with matplotlib annotations:
# decision nodes use a "sawtooth" box, leaf nodes a "round4" box, both with a
# grey face colour (fc="0.8"); arrows use the "<-" arrow style.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
leafNode = {"boxstyle": "round4", "fc": "0.8"}
arrow_args = {"arrowstyle": "<-"}


# 绘制箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node with an arrow connecting it to its parent.

    Args:
        nodeTxt: text shown inside the node box.
        centerPt: (x, y) position of the node box, in axes-fraction units.
        parentPt: (x, y) point the arrow connects to, in axes-fraction units.
        nodeType: bbox style dict, e.g. ``decisionNode`` or ``leafNode``.

    Draws on ``createPlot.ax1``, which ``createPlot()`` must have set first.
    """
    # Fixed: the attribute is ax1 (was "axl"); nodeType was passed
    # positionally into the xycoords slot (TypeError: multiple values); and
    # bbox received the string "nodeType" instead of the style dict.
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords="axes fraction",
                            xytext=centerPt, textcoords="axes fraction",
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)
def createPlot():
    """Demo: draw one sample decision node and one sample leaf node, then show.

    The axes are stored as ``createPlot.ax1`` so that ``plotNode`` can draw
    onto them.
    """
    figure = plt.figure(1, facecolor="white")
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    sampleNodes = (
        ("决策节点", (0.5, 0.1), (0.1, 0.5), decisionNode),
        ("叶节点", (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for text, center, parent, style in sampleNodes:
        plotNode(text, center, parent, style)
    plt.show()


def getNumLeafs(myTree):
    """Count the leaves of a nested-dict decision tree.

    Args:
        myTree: dict whose first key is the root feature name; its value maps
            each feature value to either a subtree dict or a leaf label.

    Returns:
        int: number of leaf (non-dict) entries in the whole tree.
    """
    numLeafs = 0
    # dict.keys() is not indexable in Python 3; take the first key via iter().
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    Args:
        myTree: dict whose first key is the root feature name; its value maps
            each feature value to either a subtree dict or a leaf label.

    Returns:
        int: 1 for a tree whose root's children are all leaves, plus one for
        each further level of nested subtree.
    """
    maxDepth = 0
    # dict.keys is a method and dict views are not indexable; take the first
    # key via iter().
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            # A leaf contributes exactly one level. This was "+= 1", which
            # read an unbound or stale value.
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth


def retrieveTree(i):
    """Return pre-built sample decision tree number *i* (0 or 1).

    Raises IndexError for any other index. The original built the list but
    never returned anything, so callers always got None.
    """
    listOfTrees = [{"no surfacing": {0: "no", 1: {"flippers": {0: "no", 1: "yes"}}}},
                   {"no surfacing": {0: "no", 1: {"flippers": {0: {"head": {0: "no", 1: "yes"}}, 1: "no"}}}}]
    return listOfTrees[i]





if __name__ == "__main__":
    # Demo / smoke test of the functions above; every result is printed.
    myDat, labels = createDataSet()
    print("myDat:", myDat)

    # Entropy of the original two-class ("yes"/"no") data set.
    print("calcShannonEnt(myDat):", calcShannonEnt(myDat))
    print("----------------增加类别---------------------------------")
    # Relabel the first row as a third class, "maybe", and recompute the
    # entropy to show how it changes when another class appears.
    myDat[0][-1] = "maybe"
    print("myDat:", myDat)
    print("calcShannonEnt(myDat):", calcShannonEnt(myDat))
    print("----------------splitDataSet---------------------------------")
    # Aside: list.append adds b as one nested element, whereas list.extend
    # adds b's items individually (splitDataSet uses extend).
    a = [1,2,3]
    b=[4,5,6]
    a.append(b)
    print("a:", a)
    a = [1,2,3]
    a.extend(b)
    print("a:", a)

    # Fresh data (the row relabelled above is discarded), then split on
    # feature 0 for each of its values.
    myDat, labels = createDataSet()
    print("myDat:", myDat)
    temp = splitDataSet(myDat, 0, 0)
    print("temp(0,0):", temp)
    print("----------------splitDataSet---------------------------------")
    temp = splitDataSet(myDat, 0, 1)
    print("temp(0,1):", temp)
    print("----------------chooseBestFeatureToSplit---------------------------------")
    # Index of the feature with the highest information gain.
    myDat, labels = createDataSet()
    print("myDat:", myDat)
    temp = chooseBestFeatureToSplit(myDat)
    print("temp:", temp)

    # Scratch code (disabled) for inspecting one feature column:
    # dataSet = myDat
    # featList = [example[1] for example in dataSet]
    # print("featList---:", featList)
    # for example in dataSet:
    #     print("example:", example)
    print("----------------createTree---------------------------------")

    # Build the full decision tree from a fresh copy of the data set.
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels)
    print("myTree:", myTree)
    print("----------------createPlot---------------------------------")
    # Plotting demo is disabled; uncomment to draw the two sample nodes.
    # createPlot()
    




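# NOTE: The lines below are residue from the blog page this script was copied
# from (its comment box and "red packet" tipping/payment widget). They are not
# part of the program and are kept only as comments so the file stays valid Python.
# 评论
# 添加红包
#
# 请填写红包祝福语或标题
#
# 红包个数最小为10个
#
# 红包金额最低5元
#
# 当前余额3.43前往充值 >
# 需支付:10.00
# 成就一亿技术人!
# 领取后你会自动成为博主和红包主的粉丝 规则
# hope_wisdom
# 发出的红包
# 实付
# 使用余额支付
# 点击重新获取
# 扫码支付
# 钱包余额 0
#
# 抵扣说明:
#
# 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
# 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
#
# 余额充值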