# 使用决策树预测隐形眼镜类型

6 篇文章 0 订阅

### 使用决策树预测隐形眼镜类型

##### 我们有一堆原始数据（lenses.txt）
young	myope	no	reduced	no lenses
young	myope	no	normal	soft
young	myope	yes	reduced	no lenses
young	myope	yes	normal	hard
young	hyper	no	reduced	no lenses
young	hyper	no	normal	soft
young	hyper	yes	reduced	no lenses
young	hyper	yes	normal	hard
pre	myope	no	reduced	no lenses
pre	myope	no	normal	soft
pre	myope	yes	reduced	no lenses
pre	myope	yes	normal	hard
pre	hyper	no	reduced	no lenses
pre	hyper	no	normal	soft
pre	hyper	yes	reduced	no lenses
pre	hyper	yes	normal	no lenses
presbyopic	myope	no	reduced	no lenses
presbyopic	myope	no	normal	no lenses
presbyopic	myope	yes	reduced	no lenses
presbyopic	myope	yes	normal	hard
presbyopic	hyper	no	reduced	no lenses
presbyopic	hyper	no	normal	soft
presbyopic	hyper	yes	reduced	no lenses
presbyopic	hyper	yes	normal	no lenses

* 特征有四个（前4列）：age（年龄）、prescript（处方类型：近视 myope / 远视 hyper）、astigmatic（是否散光）、tearRate（眼泪分泌量）

* 隐形眼镜类别有三类(最后一列)：硬材质(hard)、软材质(soft)、不适合佩戴隐形眼镜(no lenses)

##### 构造决策树代码（trees.py）
# coding:utf-8
from math import log
import operator
import treePlotter
"""

"""
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
#为所有可能分类创建字典
for featVec in dataSet: #the the number of unique elements and their occurance
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
#以2为底求对数, prob为选择该分类的概率
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob,2) #log base 2
return shannonEnt

"""

axis-划分数据集特征
value-特征返回值
"""
def splitDataSet(dataSet, axis, value):          #会把dataSet[axis]==value的那些组数据保留（并去除value值）
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet

"""

"""
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0; bestFeature = -1
for i in range(numFeatures):        #iterate over all the features
featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
uniqueVals = set(featList)       #get a set of unique values
newEntropy = 0.0
for value in uniqueVals:      #uniqueVals中保存的是我们某个样本的特征值的所有的取值的可能性
subDataSet = splitDataSet(dataSet, i, value)
# 在这里划分数据集，比如说第一个特征的第一个取值得到一个子集，第一个特征的特征又会得到另一个特征。当然这是第二次循环
# 我们以第一个特征来说明，根据第一个特征可能的取值划分出来的子集的概率
prob = len(subDataSet)/float(len(dataSet))
# 根据这个代码就可以计算出划分子集的熵（数学期望）
newEntropy += prob * calcShannonEnt(subDataSet)
# 计算出信息增益
infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
if (infoGain > bestInfoGain):       #compare this to the best gain so far
bestInfoGain = infoGain         #if better than current best, set to best
bestFeature = i
return bestFeature                      #returns an integer

"""

"""
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]

"""

"""
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]  #classList保存标签
if classList.count(classList[0]) == len(classList):
return classList[0]#stop splitting when all of the classes are equal
if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)   # 得到列表包含所有属性值
for value in uniqueVals:
subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
return myTree

if __name__ == '__main__':
    # del removes by index; remove() deletes the first matching value.
    nums = [1, 0, 2, 0, 3, 0, 0]
    nums.remove(0)   # drops only the first 0
    print(nums)      # [1, 2, 0, 3, 0, 0]
    del nums[0]      # drops the element at index 0
    print(nums)      # [2, 0, 3, 0, 0]

    # Predict contact-lens type from the tab-separated data file.
    # BUGFIX: the original referenced an undefined name `lenses`; the
    # file has to be parsed into a list of samples first.  `with` also
    # guarantees the file handle is closed.
    with open('lenses.txt') as fr:
        lenses = [line.strip().split('\t') for line in fr]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    print(lensesTree)
    treePlotter.createPlot(lensesTree)  # draw the tree (treePlotter.py below)


##### 绘图代码（treePlotter.py）
# coding: utf-8
"""

"""
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")         # 定义文本框和箭头格式

# def createPlot():
#    fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()

def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), '')
plt.show()

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',  # 绘制带箭头的图解
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

"""

"""
def getNumLeafs(myTree):
numLeafs = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
numLeafs += getNumLeafs(secondDict[key])
else:   numLeafs +=1
return numLeafs

"""

"""
def getTreeDepth(myTree):
maxDepth = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
thisDepth = 1 + getTreeDepth(secondDict[key])
else:   thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth

def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]

"""

"""
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
depth = getTreeDepth(myTree)
firstStr = myTree.keys()[0]     #the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key],cntrPt,str(key))        #recursion
else:   #it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

if __name__ == '__main__':
    # Sanity-check leaf counting and depth on a canned tree.
    # BUGFIX: Python 2 `print x` statements replaced with print() calls.
    mytree = retrieveTree(0)
    print(mytree)
    print(getNumLeafs(mytree))   # expected: 3
    print(getTreeDepth(mytree))  # expected: 2

    # Add a third branch under the root, then plot the modified tree.
    mytree['no surfacing'][3] = 'maybe'
    createPlot(mytree)


• 0
点赞
• 6
收藏
• 打赏
• 0
评论
02-02
07-27 1万+
04-14 985
06-14 735
09-06 5255
07-09 2219
07-30 1041
11-10 1228

### “相关推荐”对你有帮助么？

• 非常没帮助
• 没帮助
• 一般
• 有帮助
• 非常有帮助

¥2 ¥4 ¥6 ¥10 ¥20

1.余额是钱包充值的虚拟货币，按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载，可以购买VIP、C币套餐、付费专栏及课程。