[Python] ID3

Excerpted from Machine Learning in Action by Peter Harrington
(the Chinese edition is 《机器学习实战》).
This post covers the ID3 algorithm: at each node, choose the split that maximizes the drop in entropy before and after the split (the information gain).



ID3 pros and cons:

  • Pros: low computational complexity; the output is easy to interpret; insensitive to missing intermediate values; can handle irrelevant features
  • Cons: prone to overfitting

General workflow of a decision tree

  1. Collect data: any method
  2. Prepare data: discretize continuous values
  3. Analyze data: any method; after the tree is built, check that it matches expectations
  4. Train the algorithm: build the tree data structure
  5. Test the algorithm: compute the error rate with the learned tree
  6. Use the algorithm: this step applies to any supervised learning task; a decision tree has the advantage that the learned model is easy to interpret

0 Entropy

Suppose the label takes $C$ classes and the $i$-th class occurs with probability $p_i$. The entropy is

$H = -\sum_{i=1}^{C} p_i \log(p_{i})$

Below is the case C = 2:

import numpy as np
import matplotlib.pyplot as plt
def entropy(p):
    # binary entropy: H(p) = -p*log(p) - (1-p)*log(1-p), natural log
    return -p * np.log(p) - (1-p) * np.log(1-p)
x = np.linspace(0.01, 0.99, 200)
plt.plot(x, entropy(x))
plt.show()

(Figure: the binary entropy curve)
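The curve peaks at p = 0.5, where the two classes are hardest to tell apart; with the natural log used above, the maximum is ln 2 ≈ 0.693. A quick check with the entropy function defined above:

# the binary entropy is maximal for a 50/50 split
print(entropy(0.5))   # ~0.6931
print(np.log(2))      # ln 2 ~ 0.6931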
As an aside, here is the cross-entropy loss used in deep learning, plotted separately for the two true labels y = 1 and y = 0:

import numpy as np
import matplotlib.pyplot as plt

def entropy1(p):
    return -np.log(p)    # cross-entropy term when the true label y = 1

def entropy2(p):
    return -np.log(1-p)  # cross-entropy term when the true label y = 0

x = np.linspace(0.01, 0.99, 200)
plt.plot(x, entropy1(x))
plt.plot(x, entropy2(x))
plt.legend(['entropy1,y=1','entropy2,y=0'])
plt.show()

(Figure: cross-entropy loss curves for y = 1 and y = 0)
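The two curves are the two branches of the binary cross-entropy loss; a minimal sketch that combines them (standard cross-entropy, not code from the book):

# binary cross-entropy: -y*log(p) - (1-y)*log(1-p)
def binary_cross_entropy(y, p):
    return -y * np.log(p) - (1 - y) * np.log(1 - p)

print(binary_cross_entropy(1, 0.9))  # small loss: confident and correct
print(binary_cross_entropy(0, 0.9))  # large loss: confident but wrong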

1 Decision Tree

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

iris = datasets.load_iris()
X = iris.data[:,2:] # shape (150, 2): petal length and petal width
y = iris.target # shape (150,): three classes 0, 1, 2

plt.scatter(X[y==0,0], X[y==0,1])
plt.scatter(X[y==1,0], X[y==1,1])
plt.scatter(X[y==2,0], X[y==2,1])
plt.show()

from sklearn.tree import DecisionTreeClassifier
# random_state fixes the random seed for reproducibility
dt_clf = DecisionTreeClassifier(max_depth=2, criterion="entropy", random_state=42)
dt_clf.fit(X, y)

(figure)
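A quick sanity check on the fitted tree (predict and score are standard sklearn estimator methods, not part of the original text):

print(dt_clf.predict(X[:5]))  # predicted classes for the first five samples
print(dt_clf.score(X, y))     # mean accuracy on the training data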
Visualization:

# ravel(): returns a view of the original data when possible (no copy unless necessary)
# flatten(): always returns a copy of the data

# np.r_ concatenates along the first axis (stacks arrays vertically; column counts must match)
# np.c_ concatenates along the second axis (stacks arrays side by side; row counts must match)

def plot_decision_boundary(model, axis):
    
    x0, x1 = np.meshgrid( # build coordinate matrices for a grid over the plotting area
        np.linspace(axis[0], axis[1], int((axis[1]-axis[0])*100)).reshape(-1, 1), # 700 points between 0.5 and 7.5
        np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)).reshape(-1, 1), # 300 points between 0 and 3
    )
    # x0 shape (300, 700)
    # x1 shape (300, 700)
    X_new = np.c_[x0.ravel(), x1.ravel()] # (210000, 2)

    y_predict = model.predict(X_new) # (210000,)
    zz = y_predict.reshape(x0.shape) # reshape predictions back to the grid shape (300, 700)

    from matplotlib.colors import ListedColormap
    custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
    
    plt.contourf(x0, x1, zz, cmap=custom_cmap)

Call it and view the result:

plot_decision_boundary(dt_clf, axis=[0.5, 7.5, 0, 3])
plt.scatter(X[y==0,0], X[y==0,1])
plt.scatter(X[y==1,0], X[y==1,1])
plt.scatter(X[y==2,0], X[y==2,1])
plt.show()

(Figure: decision boundary of the depth-2 tree over the iris petal features)

Splitting with information entropy

from collections import Counter
from math import log

# split X and y on feature (column) d at threshold value
def split(X, y, d, value):
    index_a = (X[:,d] <= value) # samples whose feature d is <= value
    index_b = (X[:,d] > value)  # samples whose feature d is > value
    return X[index_a], X[index_b], y[index_a], y[index_b]

# compute the entropy of a label vector
def entropy(y):
    counter = Counter(y)
    res = 0.0
    for num in counter.values():
        p = num / len(y)
        res += -p * log(p)
    return res

# search for the split (feature index and threshold) that gives the lowest entropy after splitting
def try_split(X, y):
    best_entropy = float('inf')
    best_d, best_v = -1, -1
    for d in range(X.shape[1]): # iterate over the features (here 0 and 1)
        sorted_index = np.argsort(X[:,d]) # indices that sort the samples by feature d
        for i in range(1, len(X)): # iterate over the samples (1..149)
            if X[sorted_index[i], d] != X[sorted_index[i-1], d]: # only where two adjacent feature values differ
                v = (X[sorted_index[i], d] + X[sorted_index[i-1], d])/2 # candidate threshold: their midpoint
                X_l, X_r, y_l, y_r = split(X, y, d, v) # split X and y on feature d at threshold v
                e = entropy(y_l) + entropy(y_r)
                if e < best_entropy: # lower entropy is better
                    best_entropy, best_d, best_v = e, d, v # keep the lowest entropy, its feature and threshold
                
    return best_entropy, best_d, best_v

Simulate the first split:

# X: (150, 2)
# y: (150,)
# compute the best feature and threshold for the first split
best_entropy, best_d, best_v = try_split(X, y)
print("best_entropy =", best_entropy)
print("best_d =", best_d)
print("best_v =", best_v)

# split the data
X1_l, X1_r, y1_l, y1_r = split(X, y, best_d, best_v)
print("y1_l:",entropy(y1_l))
print("y1_r:",entropy(y1_r))

output

best_entropy = 0.6931471805599453
best_d = 0
best_v = 2.45
y1_l: 0.0
y1_r: 0.6931471805599453

Building on the first split, once more:

# using the right branch of the first split, compute the best feature and threshold for the second split
best_entropy2, best_d2, best_v2 = try_split(X1_r, y1_r)
print("best_entropy =", best_entropy2)
print("best_d =", best_d2)
print("best_v =", best_v2)

# split the data
X2_l, X2_r, y2_l, y2_r = split(X1_r, y1_r, best_d2, best_v2)
print("y2_l:",entropy(y2_l))
print("y2_r:",entropy(y2_r))

output

best_entropy = 0.4132278899361904
best_d = 1
best_v = 1.75
y2_l: 0.30849545083110386
y2_r: 0.10473243910508653
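These two splits (threshold 2.45 on feature 0, then 1.75 on feature 1) should agree with the depth-2 sklearn tree plotted earlier; one way to check is to look at the fitted tree's internal arrays (tree_.feature and tree_.threshold; leaves are marked with -2):

print(dt_clf.tree_.feature)    # feature index tested at each node, -2 for leaves
print(dt_clf.tree_.threshold)  # threshold used at each node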

2 Simple example

(Figure: example dataset)

Entropy after splitting on the Gender feature (all logarithms here are base 2):
(figure)
$Entropy(Gender) = \frac{3}{6}\left(-\frac{2}{3}\log \frac{2}{3} -\frac{1}{3}\log \frac{1}{3}\right)+\frac{3}{6}\left(-\frac{1}{3}\log \frac{1}{3} -\frac{2}{3}\log \frac{2}{3}\right) \approx 0.918$

Entropy after splitting on the Income feature:
(figure)
$Entropy(Income) = \frac{4}{6}\left(-\frac{1}{4}\log \frac{1}{4} -\frac{3}{4}\log \frac{3}{4}\right)+\frac{2}{6}(-1\times \log 1) \approx 0.541$

Entropy after splitting on the Age feature:
(figure)
$Entropy(Age) = \frac{3}{6}\left(-\frac{2}{3}\log \frac{2}{3} -\frac{1}{3}\log \frac{1}{3}\right)+\frac{2}{6}(-1\times \log 1)+\frac{1}{6}(-1\times \log 1) \approx 0.459$

The entropy of the full sample set is:
$Information(samples) = -\frac{3}{6}\log \frac{3}{6}-\frac{3}{6}\log \frac{3}{6} = 1$

Information gain is defined as:
$Information\ Gain(x) = Information(samples) - Entropy(x)$

For each feature, the information gain is:

$IG(Gender) = I - Entropy(Gender) = 1 - 0.918 = 0.082$

$IG(Income) = I - Entropy(Income) = 1 - 0.541 = 0.459$

$IG(Age) = I - Entropy(Age) = 1 - 0.459 = 0.541$
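A small sketch that reproduces these numbers. The per-branch class counts below are inferred from the formulas above (the original data table is only shown as an image), so treat them as an assumption:

from math import log2

def weighted_entropy(groups, n):
    # groups: list of (count of class 1, count of class 2) for each branch of the split
    total = 0.0
    for a, b in groups:
        m = a + b
        e = 0.0
        for c in (a, b):
            if c:
                p = c / m
                e -= p * log2(p)
        total += (m / n) * e
    return total

n = 6
print(weighted_entropy([(2, 1), (1, 2)], n))          # Gender ~ 0.918
print(weighted_entropy([(1, 3), (2, 0)], n))          # Income ~ 0.541
print(weighted_entropy([(2, 1), (2, 0), (1, 0)], n))  # Age    ~ 0.459
# information gain = 1.0 (entropy of all samples) - weighted entropy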

Splitting the samples on the Age attribute therefore gives the largest information gain, and the tree so far is:
(figure)

The Youth branch and the other pure branch each contain samples of a single class, so they need no further splitting; the Middle branch is still mixed and must be split further.
(figure)

The entropy of the remaining samples is
$I(samples) = -\frac{2}{3}\log \frac{2}{3} -\frac{1}{3}\log \frac{1}{3} = 0.918$

After splitting on the Gender feature:
(figure)

Entropy of the Gender split:

$Entropy(Gender) = \frac{2}{3}(-1\times \log 1)+\frac{1}{3}(-1\times \log 1) = 0$

After splitting on the Income feature:
(figure)

Entropy of the Income split:
$Entropy(Income) = \frac{2}{3}\left(-\frac{1}{2}\log \frac{1}{2} -\frac{1}{2}\log \frac{1}{2}\right)+\frac{1}{3}(-1\times \log 1) = 0.667$
By the information-gain formula above, IG(Gender) > IG(Income), so Gender is used to split these remaining samples.
The result:
(figure)

All leaf nodes now contain samples of a single class, so the tree is complete.

3 Constructing the decision tree

3.1 Information gain

The experimental dataset is defined below.

Function to build the dataset:

# trees.py (1): first block of code, build the dataset
from math import log
import operator
# build the dataset
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing','flippers'] # features: can survive without surfacing; has flippers
    #change to discrete values
    return dataSet, labels

Call it:

from math import log
import trees # the decision-tree code, trees.py
import operator
# call createDataSet()
myDat,labels = trees.createDataSet()  
print ('myDat:',myDat)
print ('labels:',labels)

Result:

myDat: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels: ['no surfacing', 'flippers']

Computing the information gain
After Claude Shannon developed information theory, John von Neumann suggested he call his measure "entropy". "Many people at Bell Labs and MIT compared Shannon to Einstein, while others felt that comparison was unfair to Shannon."
If the item to be classified may fall into one of several classes, the information of class $x_{i}$ is defined as:
$l(x_{i})=-\log_{2} p(x_{i})$

where $p(x_{i})$ is the probability of that class.

To compute the entropy, we take the expected value of the information over all possible classes:

$H = -\sum_{i=1}^{n}p(x_{i})\log_{2} p(x_{i})$

Note: the higher the entropy, the more disordered the system and the more mixed the data. One intuition: things tend toward disorder, i.e. entropy increases. The purer a set is, the closer $p$ is to 1, the closer $\log p$ is to 0, and the smaller the entropy.
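For the five-sample dataset created above (2 'yes' and 3 'no'), this gives
$H = -\frac{2}{5}\log_{2}\frac{2}{5} - \frac{3}{5}\log_{2}\frac{3}{5} \approx 0.971$
which is exactly what calcShannonEnt below returns.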

# trees.py (2): second block of code, compute the entropy of a dataset
def calcShannonEnt(dataSet):
    numEntries = len(dataSet) # 5
    labelCounts = {} # dictionary of counts for every possible class
    for featVec in dataSet: # e.g. [1, 1, 'yes'], [1, 1, 'yes']
        currentLabel = featVec[-1] # the label of the current sample, e.g. 'yes' or 'no'
        if currentLabel not in labelCounts.keys(): 
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    # labelCounts is {'yes': 2, 'no': 3}
    shannonEnt = 0.0
    for key in labelCounts: # key is 'yes' or 'no'
        prob = float(labelCounts[key])/numEntries # probability p: count of this label over the total number of samples
        shannonEnt -= prob * log(prob,2) # log base 2
    return shannonEnt

Test it:

trees.calcShannonEnt(myDat)

Result:

0.9709505944546686

Add a new class:

myDat[0][-1]='maybe' # change the label (last attribute) of the first sample to 'maybe'
trees.calcShannonEnt(myDat)

Result:

1.3709505944546687

This confirms: the higher the entropy, the more disordered the system and the more mixed the data.
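A hand calculation agrees: with labels 'maybe', 'yes', 'no' occurring 1, 1 and 3 times out of 5,
$H = -\frac{1}{5}\log_{2}\frac{1}{5} - \frac{1}{5}\log_{2}\frac{1}{5} - \frac{3}{5}\log_{2}\frac{3}{5} \approx 1.371$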

3.2 Splitting the dataset

# trees.py (3): third block of code, split the dataset on a given feature
# select column axis and compare it with value; when they are equal, output the row with column axis removed
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet: # featVec is e.g. [1, 1, 'yes']
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]     # chop out the axis used for splitting: [0, axis)
            reducedFeatVec.extend(featVec[axis+1:]) # [axis+1, end)
            retDataSet.append(reducedFeatVec)
    return retDataSet

Test it:

myDat,labels = trees.createDataSet()
# arguments: dataSet, axis, value
print (trees.splitDataSet(myDat,0,1)) # keep rows whose column 0 equals 1, with column 0 removed
print (trees.splitDataSet(myDat,0,0))
print (trees.splitDataSet(myDat,1,0))

Result:

[[1, 'yes'], [1, 'yes'], [0, 'no']]
[[1, 'no'], [1, 'no']]
[[1, 'no']]

Having a way to split the dataset is not enough; we still need to choose the best split. The idea: iterate over every feature, collect the set of values that feature takes, split the dataset on each value and compute the weighted entropy of the resulting subsets, then keep the feature that yields the largest information gain (i.e., the smallest entropy after splitting) as bestFeature. Note that the result is the feature index: 0 for the first feature, 1 for the second.

The implementation:

# trees.py (4): fourth block of code, choose the best feature to split on
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels; here numFeatures is 2
    baseEntropy = calcShannonEnt(dataSet) #0.9709505944546686
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):        #iterate over all the features
        featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
        # on the first pass this is [1, 1, 1, 0, 0]: the first feature of the five samples
        uniqueVals = set(featList)       #get a set of unique values, here {0, 1}
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)# rows whose feature i equals value, with feature i removed
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)     
        infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
        # note: after splitting, the data is more ordered, so its entropy decreases
        # the gains here are 0.4199730940219749 and 0.17095059445466854
        if (infoGain > bestInfoGain):       #compare this to the best gain so far
            bestInfoGain = infoGain         #if better than current best, set to best
            bestFeature = i
    return bestFeature                     #returns an integer

Test it:

trees.chooseBestFeatureToSplit(myDat)

Result:

0
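
The gains behind this choice can be reproduced with the functions already defined (a sketch, not part of trees.py); it prints roughly 0.420 for feature 0 and 0.171 for feature 1, matching the comment in chooseBestFeatureToSplit:

base = trees.calcShannonEnt(myDat)
for i in range(2):
    newEntropy = 0.0
    for value in set(example[i] for example in myDat):
        sub = trees.splitDataSet(myDat, i, value)
        newEntropy += len(sub) / float(len(myDat)) * trees.calcShannonEnt(sub)
    print('feature', i, 'info gain:', base - newEntropy)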

3.3 Building the tree recursively

With a way to pick the best feature, we can now build the decision tree recursively.
The recursion stops when:

1. all attributes used for splitting have been exhausted, or
2. all instances under a branch belong to the same class.

If the first condition is reached while the class labels are still not unique, we have to decide how to label that leaf node; in that case we usually take a majority vote.

# trees.py (5): fifth block of code, majority vote
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    # reverse=True sorts by count in descending order
    return sortedClassCount[0][0]
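
For example, a quick check of the vote:

print(trees.majorityCnt(['yes', 'no', 'no']))  # 'no' gets the majority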

Recursively build the decision tree:

# trees.py (6): sixth block of code, build the decision tree
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]   #['yes', 'yes', 'no', 'no', 'no']
    if classList.count(classList[0]) == len(classList): # e.g. the count of 'yes' equals the length of the list
        return classList[0]#stop splitting when all labels of the classes are equal
    if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
        # only the label column is left: all features have been used up
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet) # 0
    bestFeatLabel = labels[bestFeat] # 'no surfacing'
    myTree = {bestFeatLabel:{}}# {'no surfacing': {}}
    del(labels[bestFeat])# ['flippers'] remains
    featValues = [example[bestFeat] for example in dataSet] # [1, 1, 1, 0, 0]
    uniqueVals = set(featValues) # turn it into a set
    for value in uniqueVals:
        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree  

Test it:

myDat,labels = trees.createDataSet()  
myTree = trees.createTree(myDat,labels)
print (myTree)

Result:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
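
The tree is just a nested dictionary, so a new sample can be classified by walking it. A minimal sketch (the book adds a similar classify function to trees.py in a later section; this version is only an illustration):

def classify(inputTree, featLabels, testVec):
    # walk down the nested dict until we hit a leaf (a plain class label)
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # which column of testVec this node tests
    valueOfFeat = secondDict[testVec[featIndex]]
    if isinstance(valueOfFeat, dict):
        return classify(valueOfFeat, featLabels, testVec)
    return valueOfFeat

# createTree deleted entries from labels, so rebuild them before classifying
myDat, labels = trees.createDataSet()
print(classify(myTree, labels, [1, 0]))  # -> 'no'
print(classify(myTree, labels, [1, 1]))  # -> 'yes'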

4 Plotting the tree with Matplotlib annotations in Python

treePlotter.py is implemented as follows:

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# count the leaf nodes, which determines the width (x extent) of the plot (by parsing the nested dict)
def getNumLeafs(myTree):
    # myTree is {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    numLeafs = 0
    firstStr = list(myTree.keys())[0]# 'no surfacing'
    # 'dict_keys' object does not support indexing in Python 3, so wrap it in list()
    # keys() returns the dictionary keys; firstStr is the first node, i.e. the root
    secondDict = myTree[firstStr]
    #{0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
    for key in secondDict.keys(): # key is 0 or 1
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictionaries; if not they are leaf nodes
            # this recursive call is the core of the function
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

# get the depth of the tree, which determines the height (y extent) of the plot
def getTreeDepth(myTree):
    # myTree is {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    maxDepth = 0
    firstStr = list(myTree.keys())[0]# 'no surfacing'
    # 'dict_keys' object does not support indexing in Python 3, so wrap it in list()
    secondDict = myTree[firstStr]
    #{0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictionaries; if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: 
            maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
# add text between a parent node and its child (the feature value on that branch)
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
# compute the width and depth of the subtree, then plot it
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes   
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

# createPlot1 does the real plotting; createPlot further below is just a demo
def createPlot1(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])# {'xticks': [], 'yticks': []}
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    plotTree.totalW = float(getNumLeafs(inTree)) #3.0
    plotTree.totalD = float(getTreeDepth(inTree)) #2.0
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
    
# createPlot is just a demo
def createPlot():
    fig = plt.figure(1, facecolor='white')# facecolor sets the figure background color
    fig.clf()
    createPlot.ax1 = plt.subplot(1,1,1, frameon=False) #ticks for demo purposes; args are rows, cols, index
    # frameon=True draws a rectangular frame around the axes; False removes it
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode) 
    # the first point is the center of the text box, the second is where the arrow starts
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
# store the trees ahead of time, so they don't have to be rebuilt every time the code is tested
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

#createPlot(thisTree)

4.1 Matplotlib annotations

import matplotlib.pyplot as plt
import treePlotter
treePlotter.createPlot()

Result:
(Figure: demo plot showing one decision node and one leaf node)

4.2 Constructing the annotated tree

myDat,labels = trees.createDataSet()  
myTree = trees.createTree(myDat,labels)
print (myTree)
print ('number of leaves:',treePlotter.getNumLeafs(myTree))
print ('depth of the tree:',treePlotter.getTreeDepth(myTree))

Result:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
number of leaves: 3
depth of the tree: 2

Plot the tree:

treePlotter.createPlot1(myTree)

(Figure: the plotted decision tree)
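
retrieveTree also stores a second, deeper tree, which can be drawn the same way:

myTree2 = treePlotter.retrieveTree(1)
treePlotter.createPlot1(myTree2)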
