先准备数据,再准备分类器,max_depth参数就是决策树的深度。
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import tree,datasets
from sklearn.model_selection import train_test_split
wine = datasets.load_wine()
X = wine.data[:,:2]
y = wine.target
X_train,X_test,y_train,y_test = train_test_split(X,y)
clf = tree.DecisionTreeClassifier(max_depth=1)
clf.fit(X_train,y_train)
cmap_light = ListedColormap(['#FFAAAA','#AAFFAA','#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000','#00FF00','#0000FF'])
x_min,x_max = X_train[:,0].min() -1,X_train[:,0].max()+1
y_min,y_max = X_train[:,1].min() -1,X_train[:,1].max()+1
xx,yy = np.meshgrid(np.arange(x_min,x_max,.02),np.arange(y_min,y_max,.02))
Z = clf.predict(np.c_[xx.ravel(),yy.ravel()])
Z = Z.reshape(xx.shape)
plt.figure()
plt.pcolormesh(xx,yy,Z,cmap=cmap_light)
plt.scatter(X[:,0],X[:,1],c=y,cmap=cmap_bold,edgecolors='k',s=20)
plt.xlim(xx.min(),xx.max())
plt.ylim(yy.min(),yy.max())
plt.title("Classifier:(max_depth = 1)")
plt.show()
接下来加大深度看看结果有什么变化:
Z = clf2.predict(np.c_[xx.ravel(),yy.ravel()])
Z = Z.reshape(xx.shape)
plt.figure()
plt.pcolormesh(xx,yy,Z,cmap=cmap_light)
plt.scatter(X[:,0],X[:,1],c=y,cmap=cmap_bold,edgecolors='k',s=20)
plt.xlim(xx.min(),xx.max())
plt.ylim(yy.min(),yy.max())
plt.title("Classifier:(max_depth = 3)")
这时候分类器能进行三个分类的识别,而且大部分数据点都进入了正确的分类,接下来继续调整max_depth的值 = 5
接下来用graphviz库来展示一下决策树在每层之中都做了什么过程
export_graphviz(clf2,out_file="wine.dot",class_names=wine.target_names,feature_names=wine.feature_names[:2],impurity=False,filled=True)
with open("wine.dot") as f:
dot_grapg = f.read()
graph = graphviz.Source(dot_grapg)
工具就不下载了,从书上的图示也可以很轻松的看出是怎么运行的
export_graphviz()函数参数:
decision_tree:决策树回归器或分类器,要导出到GraphViz的决策树。
out_file=文件对象或字符串,可选(默认值=无)输出文件的句柄或名称。如果没有,结果将作为字符串返回。
max_depth:int,可选(默认值=无)表示的最大深度。如果没有,则完全生成树。
feature_names:字符串列表,可选(默认值=无)每个功能的名称。
class_names:字符串列表,布尔或无,可选(默认值=无)每个目标类的名称按升序排列。仅与分类相关,不支持多输出。如果为True,则显示类名的符号表示形式。
label:{'all','root','none'},可选(默认值为'all')是否显示杂质等的信息标签。选项包括“全部”在每个节点上显示,“根”仅在顶部根节点上显示,或“无”不在任何节点上显示。
filled:bool,可选(默认值=False)设置为True时,绘制节点以指示分类的多数类、回归值的极值或多重输出的节点纯度。
leaves_parallel:bool,可选(默认值=False)设置为True时,在树的底部绘制所有叶节点。
impurity:bool,可选(默认值=真)设置为True时,显示每个节点上的杂质。
node_ids:bool,可选(默认值=False)设置为True时,显示每个节点上的ID号。
proportion:bool,可选(默认值=False)设置为True时,将“值”和/或“样本”的显示分别更改为比例和百分比。
rotate:bool,可选(默认值=False)设置为True时,将树的方向从左到右,而不是从上到下。
rounded:bool,可选(默认值=False)设置为True时,使用圆角绘制节点框,并使用Helvetica字体而不是Times Roman字体。
special_characters:bool,可选(默认值=False)设置为False时,忽略特殊字符以实现PostScript兼容性。
precision:int,可选(默认值=3)每个节点的杂质值、阈值和值属性中的浮点精度位数。
Return->dot_data:串类型输入树的字符串表示形式为GraphViz点格式。仅当out_文件为None时返回。
学习还是很有效果的,现在读程序没有一开始这么困难了,基本可以知道每个函数的作用,继续努力。
决策树可以非常方便的将模型进行可视化,并不需要对数据进行转换,几乎不用对数据进行预处理。缺点就是不可避免的会出现过拟合的问题,为了解决这一问题可以使用随机森林法
随机森林
RandomForestClassifier()参数:
n_estimators:数值型取值
含义:森林中决策树的个数,默认是10criterion:字符型取值
含义:采用何种方法度量分裂质量,信息熵或者基尼指数,默认是基尼指数max_features:取值为int型, float型, string类型, or None(),默认"auto"
含义:寻求最佳分割时的考虑的特征数量,即特征数达到多大时进行分割。
int:max_features等于这个int值
float:max_features是一个百分比,每(max_features * n_features)特征在每个分割出被考虑。
"auto":max_features等于sqrt(n_features)
"sqrt":同等于"auto"时
"log2":max_features=log2(n_features)
None:max_features = n_featuresmax_depth:int型取值或者None,默认为None
含义:树的最大深度min_samples_split:int型取值,float型取值,默认为2
含义:分割内部节点所需的最少样本数量
int:如果是int值,则就是这个int值
float:如果是float值,则为min_samples_split * n_samplesmin_samples_leaf:int取值,float取值,默认为1
含义:叶子节点上包含的样本最小值
int:就是这个int值
float:min_samples_leaf * n_samplesmin_weight_fraction_leaf : float,default=0.
含义:能成为叶子节点的条件是:该节点对应的实例数和总样本数的比值,至少大于这个min_weight_fraction_leaf值max_leaf_nodes:int类型,或者None(默认None)
含义:最大叶子节点数,以最好的优先方式生成树,最好的节点被定义为杂质相对较少,即纯度较高的叶子节点min_impurity_split:float取值
含义:树增长停止的阀值。一个节点将会分裂,如果他的杂质度比这个阀值;如果比这个值低,就会成为一个叶子节点。min_impurity_decrease:float取值,默认0.
含义:一个节点将会被分裂,如果分裂之后,杂质度的减少效果高于这个值。bootstrap:boolean类型取值,默认True
含义:是否采用有放回式的抽样方式oob_score:boolean类型取值,默认False
含义:是否使用袋外样本来估计该模型大概的准确率n_jobs:int类型取值,默认1
含义:拟合和预测过程中并行运用的作业数量。如果为-1,则作业数设置为处理器的core数。class_weight:dict, list or dicts, "balanced"
含义:如果没有给定这个值,那么所有类别都应该是权重1
对于多分类问题,可以按照分类结果y的可能取值的顺序给出一个list或者dict值,用来指明各类的权重.
"balanced"模式,使用y值自动调整权重,该模式类别权重与输入数据中的类别频率成反比,
即n_samples / (n_classes * np.bincount(y)),分布为第n个类别对应的实例数。
"balanced_subsample"模式和"balanced"模式类似,只是它计算使用的是有放回式的取样中取得样本数,而不是总样本数
原文链接:https://blog.csdn.net/memoryheroli/article/details/80920260
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import tree,datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
wine = datasets.load_wine()
X = wine.data[:,:2]
y = wine.target
X_train,X_test,y_train,y_test = train_test_split(X,y)
forest =RandomForestClassifier(n_estimators=6,random_state=3)
forest.fit(X_train,y_train)
cmap_light = ListedColormap(['#FFAAAA','#AAFFAA','#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000','#00FF00','#0000FF'])
x_min,x_max = X_train[:,0].min() - 1,X_train[:,0].max() + 1
y_min,y_max = X_train[:,1].min() - 1,X_train[:,1].max() + 1
xx,yy = np.meshgrid(np.arange(x_min,x_max,.02),np.arange(y_min,y_max,.02))
Z = forest.predict(np.c_[xx.ravel(),yy.ravel()])
Z = Z.reshape(xx.shape)
plt.figure()
plt.pcolormesh(xx,yy,Z,cmap=cmap_light)
plt.scatter(X[:,0],X[:,1],c=y,cmap=cmap_bold,edgecolors='k',s=20)
plt.xlim(xx.min(),xx.max())
plt.ylim(yy.min(),yy.max())
plt.title("Classifier:RandomForest")
plt.show()
实战1----要不要和相亲对象进一步发展
import pandas as pd
data = pd.read_csv('F:\python\决策树\\adult.csv',header=None,index_col=False,
names=['年龄','单位性质','权重','学历','受教育时长','婚姻状况','职业',
'家庭情况','种族','性别','资产所得','资产损失','周工作时长','原籍',
'收入'])
data_lite = data[['年龄','单位性质','学历','性别','周工作时长','职业','收入']]
print(data_lite.head())
为了方便,选取一部分数据
用get_dummies处理数据
features = data_dummies.loc[:,'年龄':'职业_ Transport-moving']
X = features.values
y = data_dummies['收入_ >50K'].values #将收入大于50k的作为预测目标
print("代码运行结果:")
print('特征形态:{} 标签形态:{}'.format(X.shape,y.shape))
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=0)
go_dating_tree = tree.DecisionTreeClassifier(max_depth=5)
go_dating_tree.fit(X_train,y_train)
print("模型得分:{:.2f}".format(go_dating_tree.score(X_test,y_test)))
根据得分可以判断这个模型的预测准确率还是很高的
将一个数据输入测试:
Mr_z = [[37,40,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0]]
dating_dec = go_dating_tree.predict(Mr_z)
if dating_dec == 1:
print("大胆去追求真爱吧!")
else:
print("不用去了,不满足你的要求")
决策树和随机森林亦可帮助用户在数据集中对数据特征的重要性进行判断,就可以让我们通过这两个算法对高维数据集进行分析。
找个人试试去
划分数据集的最大原则是:将无序的数据变得更加有序。在划分数据集之前在之后发生的变化成为信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。
集合信息度量的方式称为熵
待分类的事务可能划分在多个分类之中
其中p(xi)是选择该分类的概率,下面使用Python计算信息熵
from math import log
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]/numEntries)
shannonEnt -= prob*log(prob,2)
return shannonEnt
def createDataSet():
dataSet = [[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet,labels
myDat,labels = createDataSet()
print(myDat)
print(calcShannonEnt(myDat))
对每个特征划分数据集的结果计算一次信息熵,然后判断按照那个特征划分数据集是最好的方式。
def splitDataSet(dataSet,axis,value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
print(splitDataSet(myDat,0,0))
print(splitDataSet(myDat,0,1))
函数输入了三个参数:待划分的数据集,划分数据集的特征,需要返回的特征的值。
Python在函数中传递的是列表的引用,在函数内部对列表对象的修改,将会影响该列表对象的整个生存周期。为了消除这个不良影响,我们在函数的开始声明一个新的列表对象。遍历数据集中的每一个元素,一旦发现符合要求的值,将其添加到新创建的列表当中。
if语句将符合特征的数据抽取出来。下面进行测试。
接下来遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的特征划分方式。
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeatures = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet,i,value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob*calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain>bestInfoGain):
bestInfoGain = infoGain
bestFeatures = i
return bestFeatures
该函数实现选取特征,划分数据集,计算得出最高的划分数据集的特征。
该函数的数据必须是一种由列表元素组成的列表,而且所有的列表元素都要有相同的数据长度;
数据的最后一列或者每个实例的最后一个元素是当前实例的类别标签。
一旦满足上述要求,就可以在函数的第一行判定当前数据集包含多少特征属性。
baseEntropy = calcShannonEnt(dataSet)
计算了整个数据集的原始香农熵,保存最初的无序度量值,用于与划分完之后的数据及计算的熵值进行比较。第一个for循环遍历数据集中的所有特征。使用列表推导来创建新的列表,蒋数据集的所有第i个特征值或者所有可能存在的值写入这个新的list中。
第二个for循环便利当前特征中的所有唯一的属性值,对每个唯一属性值划分一次数据集,然后计算新的熵值,并对所有唯一特征值得到的熵求和。
if语句比较所有特征中的信息增益,返回最好特征划分的索引值。
如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时我们通常采用多数表决的方法决定叶子节点的分类。
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),
key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0]
该函数使用分类名称的列表,然后创建键值为classList中唯一值的数据字典,字典对象存储了classList中每个类标签出现的频率,最后用operator操作键值排序字典,并返回出现次数最多的分类名称。
接下来是构建树的函数代码:
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
beatFeatLabel = labels[bestFeat]
myTree = {beatFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[beatFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
return myTree
count()统计字符串中某个字符出现的次数
该函数使用两个输入参数:数据集和标签列表。后者包含数据集中所有特征的标签。首先创建了classList的列表变量,其中包含了数据集的所有类标签。
递归函数的第一个停止条件是所有的类标签完全相同,则返回该类标签。
第二个停止条件是使用完了所有特征,仍然不能将数据集划分为仅包含唯一类别的分组。使用majorityCnt函数挑选出现次数最多的尅别作为返回值。
当前数据及选取的最好的特征储存在变量bestFeat中,得到列表包含的所有属性值。
subLabels = labels[:] 复制了类标签,将其存储在新列表变量subLabels中。
绘制
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['font.sans-serif'] = ['KaiTi']
decisionNode = dict(boxstyle = "sawtooth",fc="0.8")
leafNode = dict(boxstyle = "round4",fc = "0.8")
arrow_args = dict(arrowstyle = "<-")
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext = centerPt,textcoords = 'axes fraction',
va = "center",ha = "center",bbox = nodeType,arrowprops = arrow_args)
def createPlot():
fig = plt.figure(1,facecolor='white')
fig.clf()
createPlot.ax1 = plt.subplot(111,frameon=False)
plotNode('决策布点',(0.5,0.1),(0.1,0.5),decisionNode)
plotNode('叶节点',(0.8,0.1),(0.3,0.8),leafNode)
plt.show()
createPlot()
matplotlib.rcParams['font.sans-serif'] = ['KaiTi']
讲中文字体设置成楷体以解决在Python中的中文乱码。
这样就进行了叶节点的创建。
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs((secondDict[key]))
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth>maxDepth:
maxDepth = thisDepth
return maxDepth
def retrieveTree(i):
listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':
{0:'no',1:'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers':
{0: {'head':{0:'no', 1: 'yes'}},1:'no'}}}},
]
return listOfTrees[i]
创建一个得到叶子结点个数和最大层数的函数。
然后更新函数,得到我们的函数plotTree()
def plotTree(myTree,parentPt,nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPtr = (plotTree.xOff + (1.0+float(numLeafs))/2.0/plotTree.totalW,
plotTree.yOff)
plotMidText(cntrPtr,parentPt,nodeTxt)
plotNode(firstStr,cntrPtr,parentPt,decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ =='dict':
plotTree(secondDict[key],cntrPtr,str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),
cntrPtr,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPtr,str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def plotMidText(cntrPt,parentPt,txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)
def createPlot(inTree):
fig = plt.figure(1,facecolor='white')
fig.clf()
axprops = dict(xticks=[],yticks=[])
createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree,(0.5,1.0),'')
plt.show()
createPlot()是主函数,调用plotTree(),plotTree()又依次调用前面的函数。绘制树形图的很多工作都是在函数plotTree()中完成,首先计算树的宽和高。plotTree.totalW存储树的宽度,全局变量plotTree.totalD存储树的深度,我们使用这两个变量计算树节点的摆放位置。同时通过plotTree.xOff和plotTree.yOff追踪已经绘制的节点位置,以及下一个节点的恰当位置。
然后进行测试。
myTree = retrieveTree(0)
createPlot(myTree)
myTree['no surfacing'][3] = 'maybe'
print(myTree)
createPlot(myTree)
接下来将使用决策树构建分类器,在真实数据上使用决策树分类算法。
def classify(inputTree,featLabels,testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex= featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ =='dict':
classLabel = classify(secondDict[key],featLabels,testVec)
else:
classLabel = secondDict[key]
return classLabel
print(labels)
myTree = retrieveTree(0)
print(myTree)
print(classify(myTree,labels,[1,0]))
print(classify(myTree,labels,[1,1]))
将上述代码添加到tree文件中进行测试
为了节约时间,最好能够在每次执行分类时调用已经构造好的决策树,需要使用pickle模块序列化对象。序列化对象可以在磁盘上保存对象,并在需要时读取出来。
def storeTree(inputTree,filename):
import pickle
fw = open(filename,'w')
pickle.dump(inputTree,fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
报错:TypeError: write() argument must be str, not bytes,此时报错的意思是必须是string型而不能是byte型
将'w'修改为‘wb’后出现新的报错:'gbk' codec can't decode byte 0x80 in position 0: illegal multibyte sequence
将fr=open(filename)修改为fr = open(filename,'rb+')
这个报错的意思是:在打开文件时,缺少读写方式。
修改完成后测试成功。
实战----使用决策树预测隐形眼镜类型
from math import log
import operator
import pickle
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams['font.sans-serif'] = ['KaiTi']
decisionNode = dict(boxstyle = "sawtooth",fc="0.8")
leafNode = dict(boxstyle = "round4",fc = "0.8")
arrow_args = dict(arrowstyle = "<-")
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]/numEntries)
shannonEnt -= prob*log(prob,2)
return shannonEnt
def splitDataSet(dataSet,axis,value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeatures = -1
for i in range(numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet,i,value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob*calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain>bestInfoGain):
bestInfoGain = infoGain
bestFeatures = i
return bestFeatures
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
beatFeatLabel = labels[bestFeat]
myTree = {beatFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[beatFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
return myTree
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs((secondDict[key]))
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth>maxDepth:
maxDepth = thisDepth
return maxDepth
def retrieveTree(i):
listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':
{0:'no',1:'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers':
{0: {'head':{0:'no', 1: 'yes'}},1:'no'}}}},
]
return listOfTrees[i]
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext = centerPt,textcoords = 'axes fraction',
va = "center",ha = "center",bbox = nodeType,arrowprops = arrow_args)
def plotTree(myTree,parentPt,nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPtr = (plotTree.xOff + (1.0+float(numLeafs))/2.0/plotTree.totalW,
plotTree.yOff)
plotMidText(cntrPtr,parentPt,nodeTxt)
plotNode(firstStr,cntrPtr,parentPt,decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ =='dict':
plotTree(secondDict[key],cntrPtr,str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),
cntrPtr,leafNode)
plotMidText((plotTree.xOff,plotTree.yOff),cntrPtr,str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def plotMidText(cntrPt,parentPt,txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0+cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0+cntrPt[1]
createPlot.ax1.text(xMid,yMid,txtString)
def createPlot(inTree):
fig = plt.figure(1,facecolor='white')
fig.clf()
axprops = dict(xticks=[],yticks=[])
createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree,(0.5,1.0),'')
plt.show()
def prodict():
with open("lenses.txt", "rb") as fr:
lenses = [inst.decode().strip().split('\t')for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', "tearRate"]
lensesTree = createTree(lenses, lensesLabels)
createPlot(lensesTree)
return lensesTree
def classify(inputTree,featLabels,testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex= featLabels.index(firstStr)
key = testVec[featIndex]
valueOfFeat = secondDict[key]
if isinstance(valueOfFeat,dict):
classLabel = classify(valueOfFeat,featLabels,testVec)
else:
classLabel = valueOfFeat
return classLabel
if __name__ == "__main__":
myTree = prodict()
labels = ['age', 'prescript', 'astigmatic', "tearRate"]
result = classify(myTree, labels, ["presbyopia", "hyper", "yes", "normal"])
if result == 'no lenses':
print("实力良好")
if result == 'soft':
print("轻微近视")
if result == 'hard':
print("重度近视")
我都不敢相信没报错,接下来在代码中解释各部分含义,这样日后复习还能看的清晰一点。
测试结果如下:
先发布,上次就没保存,裂开