Watermelon Book Exercise 4.3 (Entropy-Based Decision Tree)

Write a program that implements a decision-tree algorithm using information entropy as the splitting criterion, and generate a decision tree for the data in Table 4.3.
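
For reference, the splitting criterion implemented below is the book's information gain (equations 4.1 and 4.2). For a sample set $D$ in which the $k$-th class has proportion $p_k$, and an attribute $a$ taking $V$ values that partition $D$ into $D^1, \dots, D^V$:

$$\mathrm{Ent}(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k, \qquad \mathrm{Gain}(D, a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)$$

At each node the attribute (and, for continuous attributes, the threshold) with the largest gain is chosen.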

Code

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# Feature dictionary; it is used in many places below, so it is kept as a global.
featureDic = {
    '色泽': ['浅白', '青绿', '乌黑'],
    '根蒂': ['硬挺', '蜷缩', '稍蜷'],
    '敲声': ['沉闷', '浊响', '清脆'],
    '纹理': ['清晰', '模糊', '稍糊'],
    '脐部': ['凹陷', '平坦', '稍凹'],
    '触感': ['硬滑', '软粘']}


def getDataSet():
    """
    Get watermelon data set 3.0.
    :return: the encoded data set and the list of feature names.
    """
    dataSet = [
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.697, 0.460, 1],
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 0.774, 0.376, 1],
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.634, 0.264, 1],
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 0.608, 0.318, 1],
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.556, 0.215, 1],
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 0.403, 0.237, 1],
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', 0.481, 0.149, 1],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', 0.437, 0.211, 1],
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', 0.666, 0.091, 0],
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', 0.243, 0.267, 0],
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', 0.245, 0.057, 0],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', 0.343, 0.099, 0],
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', 0.639, 0.161, 0],
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', 0.657, 0.198, 0],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 0.360, 0.370, 0],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', 0.593, 0.042, 0],
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', 0.719, 0.103, 0]
    ]

    features = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感', '密度', '含糖量']

    # # The feature dictionary was originally built like this;
    # # keeping it as a global is simply more convenient.
    # featureDic = {}
    # for i in range(len(features)):
    #     featureList = [example[i] for example in dataSet]
    #     uniqueFeature = list(set(featureList))
    #     featureDic[features[i]] = uniqueFeature

    # number of possible values per nominal feature
    numList = []  # [3, 3, 3, 3, 3, 2]
    for i in range(len(features) - 2):
        numList.append(len(featureDic[features[i]]))

    # Encode: replace the strings with numbers, using 1, 2, 3 for the
    # different values of each nominal feature.
    newDataSet = []
    for dataVec in dataSet:  # for each sample
        dataNum = dataVec[-3:]  # keep the numeric part (density, sugar content, label)
        newData = []
        for i in range(len(dataVec) - 3):  # for each nominal column
            for j in range(numList[i]):  # for each possible value of that feature
                if dataVec[i] == featureDic[features[i]][j]:
                    newData.append(j+1)
        newData.extend(dataNum)  # append the numeric part to the encoded part
        newDataSet.append(newData)

    return np.array(newDataSet), features


# # test getDataSet()
# newData, features = getDataSet()
# print(newData)
# print(features)

def calEntropy(dataArr, classArr):
    """
    Calculate information entropy.
    :param dataArr: feature part of the data set
    :param classArr: class labels (0/1)
    :return: entropy
    """

    n = dataArr.size
    if n == 0:  # guard against empty subsets
        return 0.0
    data0 = dataArr[classArr == 0]
    data1 = dataArr[classArr == 1]
    p0 = data0.size / float(n)
    p1 = data1.size / float(n)
    # convention: if p = 0 then p * log2(p) = 0
    if p0 == 0:
        ent = -(p1 * np.log2(p1))
    elif p1 == 0:
        ent = -(p0 * np.log2(p0))
    else:
        ent = -(p0 * np.log2(p0) + p1 * np.log2(p1))

    return ent

# # test calEntropy()
# dataSet, _ = getDataSet()
# print(calEntropy(dataSet[:, :-1], dataSet[:, -1]))
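# # Expected output: about 0.998. The root holds 8 positive and 9 negative
# # samples, so Ent(D) = -(8/17)*log2(8/17) - (9/17)*log2(9/17) ≈ 0.998,
# # matching the value computed in the book.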


def splitDataSet(dataSet, ax, value):
    """
    Split the data on the given attribute ax and one of its values, value.
    For a nominal attribute, return the subset whose ax-th attribute equals value.
    For a numeric attribute, return two subsets split by comparison with value.

    input:
        dataSet: data set of shape (m, n): m samples, n-1 attribute columns, last column is the class.
        ax: index of the attribute to split on.
        value: 1, 2, 3, ... for nominal attributes; a number such as 0.123 for numeric ones.

    return:
        1. nominal: the subset whose ax-th attribute equals value (with that column removed).
        2. numeric: two subsets; one with values <= value, the other with values > value.
    """
    # The last 3 columns are the two numeric attributes (density, sugar content)
    # plus the class label; everything before them is nominal.
    if ax < dataSet.shape[1] - 3:
        dataS = np.delete(dataSet[dataSet[:, ax] == value], ax, axis=1)
        return dataS
    else:
        dataL = dataSet[dataSet[:, ax] <= value]
        dataR = dataSet[dataSet[:, ax] > value]
        return dataL, dataR


# # test splitDataSet()
# dataSet, _ = getDataSet()
# test1 = splitDataSet(dataSet, 3, 1)
# test2L, test2R = splitDataSet(dataSet, 6, 0.5)
# print("test1 = ", test1)
# print("test2L = ", test2L)
# print("test2R = ", test2R)


def calInfoGain(dataSet, labelList, ax, value=-1):
    """
    Compute the information gain of splitting dataSet on attribute ax.

    input:
        dataSet: data set of shape (m, n): m samples, n-1 attribute columns, last column is the class.
        labelList: attribute names, e.g. ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感', '密度', '含糖量']
        ax: index of the attribute used to compute the gain (0 = first attribute, etc.).
            The first six attributes are nominal, the last two are numeric.
        value: threshold used to split the data; only used for numeric attributes
               (defaults to -1 and is ignored for nominal ones).

    return:
        gain: the information gain
    """
    baseEnt = calEntropy(dataSet[:, :-1], dataSet[:, -1])  # entropy of D before the split

    newEnt = 0.0  # weighted entropy after the split
    if ax < dataSet.shape[1] - 3:  # nominal attribute
        num = len(featureDic[labelList[ax]])   # number of values of this attribute
        for j in range(num):
            subDataSet = splitDataSet(dataSet, ax, j+1)
            prob = len(subDataSet) / float(len(dataSet))
            if prob != 0:
                newEnt += prob * calEntropy(subDataSet[:, :-1], subDataSet[:, -1])
    else:
        # numeric attribute: split into two subsets
        dataL, dataR = splitDataSet(dataSet, ax, value)
        # entropy of each subset
        entL = calEntropy(dataL[:, :-1], dataL[:, -1])
        entR = calEntropy(dataR[:, :-1], dataR[:, -1])
        # weight each subset by its share of the samples
        newEnt = (len(dataL) * entL + len(dataR) * entR) / float(len(dataSet))

    # information gain
    gain = baseEnt - newEnt
    return gain


# # test calInfoGain(dataSet, labelList, ax, value=-1):
# data, feat = getDataSet()
# print(calInfoGain(data, feat, 2))
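# # Expected output: about 0.141, the book's value of Gain(D, 敲声); the
# # largest gain among the nominal attributes is Gain(D, 纹理) ≈ 0.381.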


def chooseBestSplit(dataSet, labelList):
    """
    Find the split that maximizes information gain. When the chosen attribute
    is nominal, the returned threshold bestThresh is -1.
    input:
        dataSet
        labelList
    return:
        bestFeature: the attribute giving the largest gain.
        bestThresh: the threshold giving the largest gain; -1 (meaningless) for nominal attributes.
        maxGain: the largest information gain.
    """
    maxGain = 0.0
    bestFeature = -1
    bestThresh = -1
    m, n = dataSet.shape
    # for each attribute
    for i in range(n - 1):
        if i < (n - 3):     # nominal
            gain = calInfoGain(dataSet, labelList, i)
            if gain > maxGain:
                bestFeature = i
                maxGain = gain
        else:   # numeric
            featVals = dataSet[:, i]  # all values of the i-th attribute
            sortedFeat = np.sort(featVals)  # sorted in ascending order
            T = []
            # candidate thresholds: midpoints of adjacent sorted values
            for j in range(m - 1):
                t = (sortedFeat[j] + sortedFeat[j + 1]) / 2.0
                T.append(t)
            # compute the gain for every candidate threshold
            for t in T:
                gain = calInfoGain(dataSet, labelList, i, t)
                if gain > maxGain:
                    bestFeature = i
                    bestThresh = t
                    maxGain = gain

    return bestFeature, bestThresh, maxGain


# # test chooseBestSplit
# data, feat = getDataSet()
# f, tv, g = chooseBestSplit(data, feat)
# print(f"best feature is {list(featureDic.keys())[f]}\n"
#       f"best thresh value is {tv}\n"
#       f"max information gain is {g}")


def majorityCnt(classList):
    """
    Majority vote: return '坏瓜' when class 0 has the majority, otherwise '好瓜'.
    """
    cnt0 = len(classList[classList == 0])
    cnt1 = len(classList[classList == 1])
    if cnt0 > cnt1:
        return '坏瓜'
    else:
        return '好瓜'


def createTree(dataSet, labels):
    """
    Recursively build a decision tree using information gain.
    input:
        labels
        dataSet

    return:
        myTree: the tree stored as a nested dict
    """
    classList = dataSet[:, -1]
    # if all remaining samples share one class, return a leaf
    if len(classList[classList == classList[0]]) == len(classList):
        if classList[0] == 0:
            return '坏瓜'
        else:
            return '好瓜'
    # if only the class column is left, return the majority vote
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    # find the attribute (and threshold) giving the largest gain
    bestFeat, bestVal, entGain = chooseBestSplit(dataSet, labels)
    bestFeatLabel = labels[bestFeat]

    if bestVal != -1:  # numeric attribute
        txt = bestFeatLabel + "<=" + str(bestVal) + "?"
    else:   # nominal attribute
        txt = bestFeatLabel + "=" + "?"

    myTree = {txt: {}}  # a dict acts as a tree node
    if bestVal != -1:   # numeric: exactly two subtrees
        subDataL, subDataR = splitDataSet(dataSet, bestFeat, bestVal)
        # pass copies so one subtree cannot mutate the other's label list
        myTree[txt]['是'] = createTree(subDataL, labels[:])
        myTree[txt]['否'] = createTree(subDataR, labels[:])
    else:
        i = 0
        # a nominal attribute is removed once used; numeric attributes may be
        # reused, so they are never removed
        del (labels[bestFeat])
        uniqueVals = featureDic[bestFeatLabel]  # possible values of the chosen attribute
        for value in uniqueVals:    # one subtree per possible nominal value
            # lists are passed by reference in Python, so copy the label list
            # to keep the siblings of this node independent
            subLabels = labels[:]  # [:] makes a copy, not another reference
            i += 1
            subDataSet = splitDataSet(dataSet, bestFeat, i)
            if len(subDataSet) == 0:
                # empty branch: fall back to the parent's majority class
                myTree[txt][value] = majorityCnt(classList)
            else:
                myTree[txt][value] = createTree(subDataSet, subLabels)

    return myTree


# # test createTree()
# data, feat = getDataSet()
# Tree = createTree(data, feat)
# print(Tree)
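
# The post never uses the finished tree for prediction. Below is a minimal
# sketch of a classifier that walks the nested dict returned by createTree;
# the helper name `classify` and the raw-valued sample format are additions
# of this write-up, not part of the original exercise code.
def classify(tree, featLabels, sample):
    """sample holds raw feature values in featLabels order, e.g.
    ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.697, 0.460]."""
    if not isinstance(tree, dict):
        return tree  # reached a leaf: '好瓜' or '坏瓜'
    txt = list(tree.keys())[0]
    if '<=' in txt:  # numeric node, e.g. '密度<=0.3815?'
        label, thresh = txt.rstrip('?').split('<=')
        branch = '是' if sample[featLabels.index(label)] <= float(thresh) else '否'
    else:  # nominal node, e.g. '纹理=?'; branches are keyed by the raw values
        label = txt.split('=')[0]
        branch = sample[featLabels.index(label)]
    return classify(tree[txt][branch], featLabels, sample)

# # usage sketch (createTree mutates the label list, so give it a copy):
# data, feat = getDataSet()
# tree = createTree(data, feat[:])
# print(classify(tree, feat, ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.697, 0.460]))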


# ***********************Plotting***********************
# **********************start***********************
# See the decision-tree chapter of Machine Learning in Action for details.

# text box and arrow styles
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
mpl.rcParams['font.sans-serif'] = ['SimHei']  # without this line Chinese characters render as boxes
# mpl.rcParams['axes.unicode_minus'] = False  # fix minus signs rendering as boxes in saved figures


def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, fontsize=20)


def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # draw an annotated node with an arrow
    createPlot.ax1.annotate(nodeTxt,
                            xy=parentPt,
                            xycoords="axes fraction",
                            xytext=centerPt,
                            textcoords="axes fraction",
                            va="center",
                            ha="center",
                            bbox=nodeType,
                            arrowprops=arrow_args,
                            fontsize=20)


def getNumLeafs(myTree):  # count the leaf nodes
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):  # compute the tree depth
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, figsize=(16, 9), facecolor='white')  # figsize is in inches
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


# ***********************Plotting***********************
# ***********************end************************


def main():
    dataSet, labelList = getDataSet()
    myTree = createTree(dataSet, labelList)
    createPlot(myTree)


if __name__ == '__main__':
    main()

Plot

(figure: the decision tree generated for watermelon data set 3.0)

Addendum

I got this wrong at first: I used the variant of the data that only has density and sugar content, defined a binary-tree node class, and spent ages adapting the tree-plotting code from Machine Learning in Action, only to realize afterwards that the exercise is not about that data at all: the nodes are not binary. Consider it accidentally written as CART. After switching to the dict-based tree above, I am keeping the earlier code here as well.

Code

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt


# define tree node
class TreeNode:
    def __init__(self, feature, thresh):
        self.feature = feature  # feature: density or sugar content
        self.thresh = thresh  # threshold used when splitting on this feature
        self.label = -1  # class: -1 except on leaves, where 0/1 means bad/good melon
        self.data = []  # the data that reached this node
        self.left = None  # child nodes
        self.right = None

    def numOfGood(self):
        # count samples labeled 1 (the last column holds the label)
        return (self.data[:, -1] == 1).sum()

    def numOfBad(self):
        # count via a boolean mask; summing the zero labels themselves
        # would always return 0
        return (self.data[:, -1] == 0).sum()


def getDataSet():
    """
    Get watermelon data set 3.0 alpha.
    :return: data array; the last column is the class label.
    """
    dataSet = np.array([
        [0.697, 0.460, 1],
        [0.774, 0.376, 1],
        [0.634, 0.264, 1],
        [0.608, 0.318, 1],
        [0.556, 0.215, 1],
        [0.403, 0.237, 1],
        [0.481, 0.149, 1],
        [0.437, 0.211, 1],
        [0.666, 0.091, 0],
        [0.243, 0.267, 0],
        [0.245, 0.057, 0],
        [0.343, 0.099, 0],
        [0.639, 0.161, 0],
        [0.657, 0.198, 0],
        [0.360, 0.370, 0],
        [0.593, 0.042, 0],
        [0.719, 0.103, 0]
    ])

    return dataSet


def calEntropy(dataArr, labelArr):
    """
    Calculate information entropy.
    :param dataArr: feature part of the data set
    :param labelArr: class labels (0/1)
    :return: entropy
    """
    n = dataArr.size
    if n == 0:  # guard against empty subsets
        return 0.0
    data0 = dataArr[labelArr == 0]
    data1 = dataArr[labelArr == 1]
    p0 = data0.size / float(n)
    p1 = data1.size / float(n)
    # convention: if p = 0 then p * log2(p) = 0
    if p0 == 0:
        entropy = -(p1 * np.log2(p1))
    elif p1 == 0:
        entropy = -(p0 * np.log2(p0))
    else:
        entropy = -(p0 * np.log2(p0) + p1 * np.log2(p1))

    return entropy


# # test calEntropy()
# dataSet = getDataSet()
# print(calEntropy(dataSet[:, :-1], dataSet[:, -1]))


def calInfoGain(dataSet, feature, thresh):
    """
    calculate information gain
    :param dataSet: data set; the last column is the class.
    :param feature: attribute used to compute the gain: 0 = first attribute, 1 = second.
    :param thresh: the threshold used to split the data.
    :return: information gain
    """
    entD = calEntropy(dataSet[:, :-1], dataSet[:, -1])  # entropy of D before the split
    # split into two subsets
    dataL = dataSet[dataSet[:, feature] <= thresh]
    dataR = dataSet[dataSet[:, feature] > thresh]
    # entropy of each subset
    entL = calEntropy(dataL[:, :-1], dataL[:, -1])
    entR = calEntropy(dataR[:, :-1], dataR[:, -1])
    # weighted entropy after the split
    entDS = (len(dataL) * entL + len(dataR) * entR) / float(len(dataSet))
    # information gain
    gain = entD - entDS
    return gain


# # test calInfoGain(dataSet, feature, thresh)
# data = getDataSet()
# print(calInfoGain(data, 0, 0.6))


def chooseBestSplit(dataSet):
    """
    Find the split that maximizes information gain.
    :param dataSet:
    :return: the attribute and threshold of the best split, plus the gain.
    """
    maxGain = 0.0
    bestFeature = -1
    bestThresh = -1
    m, n = dataSet.shape
    # for each attribute
    for i in range(n - 1):
        feat = dataSet[:, i]  # all values of the i-th attribute
        sortedFeat = np.sort(feat)  # sorted in ascending order
        T = []
        # candidate thresholds: midpoints of adjacent sorted values
        for j in range(m - 1):
            t = (sortedFeat[j] + sortedFeat[j + 1]) / 2.0
            T.append(t)
        # compute the gain for every candidate threshold
        for val in T:
            gain = calInfoGain(dataSet, i, val)
            if gain > maxGain:
                bestFeature = i
                bestThresh = val
                maxGain = gain

    return bestFeature, bestThresh, maxGain


# # test chooseBestSplit
# data = getDataSet()
# f, tv, g = chooseBestSplit(data)
# print(f"best feature is {f}\n"
#       f"best thresh value is {tv}\n"
#       f"max information gain is {g}")

def createTree(dataSet):
    """
    Build a decision tree using information gain.
    :param dataSet:
    :return: the root node of the tree
    """
    # Recursion ends at a leaf: if all remaining samples share one class,
    # the entropy is -(1 * log2(1)) = 0, so a node with entropy 0 is pure.
    if calEntropy(dataSet[:, :-1], dataSet[:, -1]) == 0:
        leaf = TreeNode(-1, -1)  # construct a leaf node
        leaf.label = dataSet[0][-1]
        return leaf

    feature, thresh, gain = chooseBestSplit(dataSet)
    dataL = dataSet[dataSet[:, feature] <= thresh]
    dataR = dataSet[dataSet[:, feature] > thresh]

    Node = TreeNode(feature, thresh)
    Node.data = dataSet
    Node.left = createTree(dataL)
    Node.right = createTree(dataR)

    return Node


# # test createTree()
# data = getDataSet()
# Tree = createTree(data)
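
# A minimal prediction sketch for this node-based tree (the `predict` helper
# is an addition of this write-up, relying on the feature == -1 leaf
# convention used by createTree above).
def predict(node, density, sugar):
    # walk the binary tree: feature 0 is density, feature 1 is sugar content;
    # feature == -1 marks a leaf whose label is 0 (bad) or 1 (good)
    while node.feature != -1:
        val = density if node.feature == 0 else sugar
        node = node.left if val <= node.thresh else node.right
    return int(node.label)

# # usage sketch:
# data = getDataSet()
# tree = createTree(data)
# print(predict(tree, 0.697, 0.460))  # expect 1 for the first training sample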


# ***********************Plotting***********************
# **********************start***********************
def getNumLeafs(myTree):
    """
    Count the leaf nodes. The None check must come first,
    otherwise accessing .feature on None raises an error.
    :param myTree:
    :return:
    """
    if myTree is None:
        return 0
    if myTree.feature == -1:  # -1 marks a leaf
        return 1

    return getNumLeafs(myTree.left) + getNumLeafs(myTree.right)


# # test getNumLeafs()
# data = getDataSet()
# Tree = createTree(data)
# print(getNumLeafs(Tree))  # 5 leaves


def getTreeDepth(myTree):
    """
    Compute the depth of the tree.
    :param myTree:
    :return:
    """
    if myTree is None:
        return 0

    # the 1 counts the current node
    depth = max(1 + getTreeDepth(myTree.left),
                1 + getTreeDepth(myTree.right))
    return depth


# # test getTreeDepth()
# data = getDataSet()
# Tree = createTree(data)
# print(getTreeDepth(Tree))  # depth is 5


# without this line Chinese characters in the plot render as boxes
mpl.rcParams['font.sans-serif'] = ['SimHei']

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

fList = ["密度", "含糖率"]  # used later when drawing nodes
melon = ["坏瓜", "好瓜"]


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    Draw a node and the arrow pointing at it.
    :param nodeTxt: text on the node
    :param centerPt: arrow end point
    :param parentPt: arrow start point, i.e. the parent node's coordinates;
                     the arrow goes from parentPt to centerPt
    :param nodeType: node style: a dict of drawing parameters;
                     decisionNode for internal nodes, leafNode for leaves
    :return:
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,
                            xycoords='axes fraction',
                            xytext=centerPt,
                            textcoords='axes fraction',
                            va="center", ha="center",
                            bbox=nodeType,
                            arrowprops=arrow_args,
                            fontsize=15)  # node font size


# def createPlot():
#     fig = plt.figure(1, facecolor='white')
#     fig.clf()
#     createPlot.ax1 = plt.subplot(111, frameon=False)
#     plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
#     plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
#     plt.show()

# createPlot()


def plotMidText(cntrPt, parentPt, txtString):
    """
    Draw text (such as "是"/"否") at the midpoint of the arrow
    between a parent node and a child node.
    :param cntrPt: child node coordinates
    :param parentPt: parent node coordinates
    :param txtString: the text to draw
    :return:
    """
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString,
                        va="center", ha="center", rotation=30,
                        fontsize=15)


def plotTree(myTree, parentPt, nodeTxt):
    """
    Draw the tree recursively.
    :param myTree: tree node
    :param parentPt: parent node coordinates
    :param nodeTxt: text drawn on the incoming edge
    :return:
    """
    numLeafs = getNumLeafs(myTree)  # determines the x width of this subtree
    depth = getTreeDepth(myTree)
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)

    if myTree.thresh != -1:  # internal node
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(str(fList[myTree.feature])  # fList = ["密度", "含糖率"]
                 + "<=" + str(myTree.thresh) + "?",
                 cntrPt, parentPt, decisionNode)
        plotTree.yOff = plotTree.yOff - 1 / plotTree.totalD
    else:  # leaf node
        plotTree.xOff = plotTree.xOff + 1 / plotTree.totalW
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(melon[int(myTree.label)], cntrPt, parentPt, leafNode)  # melon = ["坏瓜", "好瓜"]
        plotTree.yOff = plotTree.yOff - 1 / plotTree.totalD

    if myTree.left is not None:
        plotTree(myTree.left, cntrPt, "是")
    if myTree.right is not None:
        plotTree(myTree.right, cntrPt, "否")

    plotTree.yOff = plotTree.yOff + 1 / plotTree.totalD


def createPlot(inTree):
    """
    Set up the figure (tree width, depth, initial offsets) and call plotTree().
    :param inTree:
    :return:
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


# ***********************画图***********************
# ***********************end************************


def main():
    data = getDataSet()
    Tree = createTree(data)
    createPlot(Tree)


if __name__ == '__main__':
    main()

Plot

(figure: the decision tree generated for watermelon data set 3.0 alpha, using density and sugar content only)
