理论方面机器学习实战中讲的非常清楚,深入点的话在西瓜书可以参考,这里只把源码贴出来和学习中的一些困难。
这里主要主要是有这么几块:
首先搞懂信息熵和其作用
划分数据集
递归构建决策树
Matplotlib注解绘制树形图
测试和存储分类器
示例:使用决策树预测隐形眼镜类型
构建一个决策树:
from math import log
import operator
#import pickle
#import tree_plot
# 自己建立的数据
def createDataSet():
# 各元素也是列表
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing','flippers'] # 特征标签
return dataSet, labels
# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
# 数据集中实例的总数
numEntries = len(dataSet)
# 频数统计
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1] # 最后一列作为标签
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 # 统计不同的标签数
#print labelCounts
shannonEnt = 0.0
# 计算香农熵
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob,2) # 以2为底
return shannonEnt
# 按照给定特征划分数据集,返回剩余特征
# 三个参数:待划分的数据集,划分的特征,需返回的特征的值
def splitDataSet(dataSet, axis, value):
# 创建新的列表,避免影响原数据
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value: # 提取特定特征中的特定值
reducedFeatVec = featVec[:axis] # 得到axis列之前列的特征
# 在处理多个列表时,append()是添加的列表元素,而extend()添加的是元素
reducedFeatVec.extend(featVec[axis+1:]) # 得到axis列之后列的特征
retDataSet.append(reducedFeatVec)
return retDataSet # 返回指定划分特征外的特征
# 选择最好的数据集划分方式,根据熵计算
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 留出来一列,把最后一列作为类别标签
baseEntropy = calcShannonEnt(dataSet) # 计算原始香浓熵,划分前
bestInfoGain = 0.0; bestFeature = -1 # 赋初值
for i in range(numFeatures): # 迭代所有的特征
# 使用列表推导来创建新的列表,featlist得到的是每列的特征元素
featList = [example[i] for example in dataSet]
print 'featList:',featList
uniqueVals = set(featList) # 转换为集合,以此确保其中的元素的唯一性
print 'uniqueVals:',uniqueVals
newEntropy = 0.0
for value in uniqueVals:
# 每一列按照不重复的元素划分,返回剩余特征
subDataSet = splitDataSet(dataSet, i, value)
print 'subDataSet:',subDataSet
prob = len(subDataSet)/float(len(dataSet)) # 频率
# 得到此次划分的熵,此处的prob和calcShannonEnt()中的prob不是同一种,第一个
# 是实例数在整体数组的频率,第二个是部分数组中的标签频率,newEntropy求的是信息期望
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy # 计算信息增益,即熵的减少
print 'infogain:',infoGain
if (infoGain > bestInfoGain): #如果信息量减少,就把减少量作为基准
bestInfoGain = infoGain
bestFeature = i
return bestFeature # 返回信息信息增益最高的特征列
# 多数表决法决定叶节点分类
def majorityCnt(classList): # classList是分类名称的列表
classCount={} # 存储每个类标签出现的频率
for vote in classList:
# 统计所有的不重复的key
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
# 对分类进行倒序排列
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),\
reverse=True)
# 返回出现次数最多的名称
return sortedClassCount[0][0]
# 创建树的函数代码
def createTree(dataSet,labels): # labels存储的是特征标签
classList = [example[-1] for example in dataSet] # 数据集的最后一列作为类标签列表
print 'classList:',classList
# 判断类别是否完全相同,通过查看类标签的第一个的数目
if classList.count(classList[0]) == len(classList):
return classList[0]
# 判断是否遍历完所有的特征,通过查看剩下的特征数是不是剩下一个
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) # 找到当前数据集最好的特征的索引
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}} # 嵌套字典,得到一个当前最好的特征标签
del(labels[bestFeat]) # 删除当前最好的特征标签。即划分时类别特征数目减少
# 列表推到式得到数据集的最好特征标签的那列
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues) # 集合确保元素的唯一性
for value in uniqueVals:
subLabels = labels[:] # 得到当前剩余的特征标签
# 递归调用
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,\
bestFeat, value),subLabels)
return myTree
# 创建自己的数据集
mydat,labels=createDataSet() # mydat,lables相当于全局变量
print calcShannonEnt(mydat)
mydat[0][-1]='maybe'
print calcShannonEnt(mydat)
print(splitDataSet(mydat,1,1))
print(splitDataSet(mydat,0,1))
print('bestFeature:',chooseBestFeatureToSplit(mydat))
print '.............................'
print(createTree(mydat,labels))
看运行的结果:
0.970950594455
1.37095059445
[[1, 'maybe'], [1, 'yes'], [0, 'no'], [0, 'no']]
[[1, 'maybe'], [1, 'yes'], [0, 'no']]
featList: [1, 1, 1, 0, 0]
uniqueVals: set([0, 1])
subDataSet: [[1, 'no'], [1, 'no']]
subDataSet: [[1, 'maybe'], [1, 'yes'], [0, 'no']]
infogain: 0.419973094022
featList: [1, 1, 0, 1, 1]
uniqueVals: set([0, 1])
subDataSet: [[1, 'no']]
subDataSet: [[1, 'maybe'], [1, 'yes'], [0, 'no'], [0, 'no']]
infogain: 0.170950594455
('bestFeature:', 0)
.............................
classList: ['maybe', 'yes', 'no', 'no', 'no']
featList: [1, 1, 1, 0, 0]
uniqueVals: set([0, 1])
subDataSet: [[1, 'no'], [1, 'no']]
subDataSet: [[1, 'maybe'], [1, 'yes'], [0, 'no']]
infogain: 0.419973094022
featList: [1, 1, 0, 1, 1]
uniqueVals: set([0, 1])
subDataSet: [[1, 'no']]
subDataSet: [[1, 'maybe'], [1, 'yes'], [0, 'no'], [0, 'no']]
infogain: 0.170950594455
classList: ['no', 'no']
classList: ['maybe', 'yes', 'no']
featList: [1, 1, 0]
uniqueVals: set([0, 1])
subDataSet: [['no']]
subDataSet: [['maybe'], ['yes']]
infogain: 0.918295834054
classList: ['no']
classList: ['maybe', 'yes']
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'maybe'}}}}
上述结果显示了算法运行的具体过程,最后一行得到了决策树的数据结构,用的是python的字典来存储的,可以看出是一种层级关系。
在函数createTree(dataSet,labels)中:
myTree = {bestFeatLabel:{}} # 嵌套字典,得到一个当前最好的特征标签
# 递归调用
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,\
bestFeat, value),subLabels)
这两行我起初觉得myTree字典被重置,其实是递归的作用,我的理解是这样的
字典递归
还有就是递归终止条件的判断要注意。。
使用Matplotlib绘制树形图:
可能会用到的函数:
matplotlib(1)
matplotlib(2)
matplotlib(3)
# -*- coding: utf-8 -*-
"""
绘制树节点
Created on Thu Aug 10 10:37:02 2017
@author: LiLong
"""
#import decision_tree.py
import matplotlib.pyplot as plt
# boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8") # 定义决策树的叶子结点的描述属性
arrow_args = dict(arrowstyle="<-") # 定义箭头属性,也可以是<->,效果就变成双箭头的了
# 绘制结点文本和指向
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
#nodeTxt为要显示的文本,xytext是文本的坐标,
#xy是注释点的坐标 ,nodeType是注释边框的属性,arrowprops连接线的属性
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# 获取叶节点的数目
def getNumLeafs(myTree):
numLeafs = 0
firstStr = myTree.keys()[0] # 得到第一个键
secondDict = myTree[firstStr] # 得到第一个键对应的值
for key in secondDict.keys():
# 测试节点的数据类型是否是字典
if type(secondDict[key]).__name__=='dict':
numLeafs += getNumLeafs(secondDict[key]) # 又是递归调用
else: numLeafs +=1
return numLeafs # 返回叶节点数
# 获取树的层数(递归在此就像是一层一层的剥到最里面,然后再从里到外加起来)
def getTreeDepth(myTree):
maxDepth = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys(): #keys()函数得到的是key,是一个列表
#print'key:',key
# 测试节点的数据类型是否是字典,如果是字典说明是可以再分的,深度+1
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key]) # 递归调用,层层剥离字典
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
# 绘制中间文本的坐标和显示内容,即父子之间的填充文本
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] # 求中间点的横坐标
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
# 绘制出来此文本
createPlot.ax1.text(xMid, yMid, txtString)
# 绘制树形图
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) # 得到叶节点的数,宽
print 'numLeafs:',numLeafs
depth = getTreeDepth(myTree) # 获得树的层数,高
firstStr = myTree.keys()[0] # 得到第一个划分的特征
# 计算坐标
print 'plotTree.xOff:',plotTree.xOff
print 'plotTree.totalW:',plotTree.totalW
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, \
plotTree.yOff)
print 'cntrPt:',cntrPt
# cntrPt是刚计算的坐标,parentPt是父节点坐标,nodeTxt目前为空字符
plotMidText(cntrPt, parentPt, nodeTxt) # 绘制连接线上的文本
plotNode(firstStr, cntrPt, parentPt, decisionNode) # 绘制树节点
secondDict = myTree[firstStr] # 下一级字典,即下一层
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 纵坐标降低
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict': # 如果是树节点
plotTree(secondDict[key],cntrPt,str(key)) #递归,绘制
else: #如果是一个叶节点,就绘制出来
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW # 定x坐标
# secondDict[key]叶节点文本,(plotTree.xOff, plotTree.yOff)箭头指向的坐标
# cntrPt注释(父节点)的坐标
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) #绘制文本及连线
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) # 绘制父子填充文本
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
# 预留树信息
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
# Axis为坐标轴,Label为坐标轴标注。Tick为刻度线,ax是坐标系区域
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
# 横纵坐标轴的刻度线,应该为空,加上范围后,父子间的节点连线的填充文本位置错乱
axprops = dict(xticks=[], yticks=[]) # {'xticks': [], 'yticks': []}
# createPlot.ax1创建绘图区,无边框,无刻度值
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
#createPlot.ax1 = plt.subplot(111, frameon=False)
# 计算树形图的全局变量,用于计算树节点的摆放位置,将树绘制在中心位置
plotTree.totalW = float(getNumLeafs(inTree)) # plotTree.totalW保存的是树的宽
plotTree.totalD = float(getTreeDepth(inTree)) # plotTree.totalD保存的是树的高
plotTree.xOff = -0.5/plotTree.totalW # 决策树起始横坐标
plotTree.yOff = 1.0 # 决策树的起始纵坐标
plotTree(inTree, (0.5,1.0), '') # 绘制树形图
plt.show() # 显示
mytree=retrieveTree(0)
getNumLeafs(mytree)
getTreeDepth(mytree)
createPlot(mytree)
运行结果:
numLeafs: 3
plotTree.xOff: -0.166666666667
plotTree.totalW: 3.0
cntrPt: (0.5, 1.0)
numLeafs: 2
plotTree.xOff: 0.166666666667
plotTree.totalW: 3.0
cntrPt: (0.6666666666666666, 0.5)
决策树图的上方代码是算法过程中的一些参数变化,有助于理解。其中决策树绘制过程中坐标的计算有点复杂。。。
下面是一些简单的知识点:
函数也是对象,给一个对象绑定一个属性就是这样的:
def f():
pass
f.a = 1
print f.a
>>> os.getcwd()
'C:\\Users\\LiLong'
>>> os.chdir('C:\\Users\\LiLong\\Desktop\\decision_tree')
>>> os.getcwd()
'C:\\Users\\LiLong\\Desktop\\decision_tree'
>>>
使用决策树分类并预测隐形眼镜类型
tree_plot.py
# -*- coding: utf-8 -*-
"""
绘制树节点
Created on Thu Aug 10 10:37:02 2017
@author: LiLong
"""
#import decision_tree.py
import matplotlib.pyplot as plt
# boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8") # 定义决策树的叶子结点的描述属性
arrow_args = dict(arrowstyle="<-") # 定义箭头属性,也可以是<->,效果就变成双箭头的了
# 绘制结点文本和指向
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
#nodeTxt为要显示的文本,xytext是文本的坐标,
#xy是注释点的坐标 ,nodeType是注释边框的属性,arrowprops连接线的属性
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
# 获取叶节点的数目
def getNumLeafs(myTree):
numLeafs = 0
firstStr = myTree.keys()[0] # 得到第一个键
secondDict = myTree[firstStr] # 得到第一个键对应的值
for key in secondDict.keys():
# 测试节点的数据类型是否是字典
if type(secondDict[key]).__name__=='dict':
numLeafs += getNumLeafs(secondDict[key]) # 又是递归调用
else: numLeafs +=1
return numLeafs # 返回叶节点数
# 获取树的层数(递归在此就像是一层一层的剥到最里面,然后再从里到外加起来)
def getTreeDepth(myTree):
maxDepth = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys(): #keys()函数得到的是key,是一个列表
#print'key:',key
# 测试节点的数据类型是否是字典,如果是字典说明是可以再分的,深度+1
if type(secondDict[key]).__name__=='dict':
thisDepth = 1 + getTreeDepth(secondDict[key]) # 递归调用,层层剥离字典
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
# 绘制中间文本的坐标和显示内容,即父子之间的填充文本
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] # 求中间点的横坐标
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
# 绘制出来此文本
createPlot.ax1.text(xMid, yMid, txtString)
# 绘制树形图
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree) # 得到叶节点的数,宽
print 'numLeafs:',numLeafs
depth = getTreeDepth(myTree) # 获得树的层数,高
firstStr = myTree.keys()[0] # 得到第一个划分的特征
# 计算坐标
print 'plotTree.xOff:',plotTree.xOff
print 'plotTree.totalW:',plotTree.totalW
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, \
plotTree.yOff)
#print 'cntrPt:',cntrPt
# cntrPt是刚计算的坐标,parentPt是父节点坐标,nodeTxt目前为空字符
plotMidText(cntrPt, parentPt, nodeTxt) # 绘制连接线上的文本
plotNode(firstStr, cntrPt, parentPt, decisionNode) # 绘制树节点
secondDict = myTree[firstStr] # 下一级字典,即下一层
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 纵坐标降低
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict': # 如果是树节点
plotTree(secondDict[key],cntrPt,str(key)) #递归,绘制
else: #如果是一个叶节点,就绘制出来
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW # 定x坐标
# secondDict[key]叶节点文本,(plotTree.xOff, plotTree.yOff)箭头指向的坐标
# cntrPt注释(父节点)的坐标
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) #绘制文本及连线
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) # 绘制父子填充文本
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
# 预留树信息
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
# Axis为坐标轴,Label为坐标轴标注。Tick为刻度线,ax是坐标系区域
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
# 横纵坐标轴的刻度线,应该为空,加上范围后,父子间的节点连线的填充文本位置错乱
axprops = dict(xticks=[], yticks=[]) # {'xticks': [], 'yticks': []}
# createPlot.ax1创建绘图区,无边框,无刻度值
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
#createPlot.ax1 = plt.subplot(111, frameon=False)
# 计算树形图的全局变量,用于计算树节点的摆放位置,将树绘制在中心位置
plotTree.totalW = float(getNumLeafs(inTree)) # plotTree.totalW保存的是树的宽
plotTree.totalD = float(getTreeDepth(inTree)) # plotTree.totalD保存的是树的高
plotTree.xOff = -0.5/plotTree.totalW # 决策树起始横坐标
plotTree.yOff = 1.0 # 决策树的起始纵坐标
plotTree(inTree, (0.5,1.0), '') # 绘制树形图
plt.show() # 显示
decision_tree.py
# coding=utf-8
from math import log
import operator
import pickle
import tree_plot # 导入decision_tree.py
# 自己建立的数据
def createDataSet():
# 各元素也是列表
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing','flippers'] # 特征标签
return dataSet, labels
# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
# 数据集中实例的总数
numEntries = len(dataSet)
# 频数统计
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1] # 最后一列作为标签
if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 # 统计不同的标签数
#print labelCounts
shannonEnt = 0.0
# 计算香农熵
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
shannonEnt -= prob * log(prob,2) # 以2为底
return shannonEnt
# 按照给定特征划分数据集,返回剩余特征
# 三个参数:待划分的数据集,划分的特征,需返回的特征的值
def splitDataSet(dataSet, axis, value):
# 创建新的列表,避免影响原数据
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value: # 提取特定特征中的特定值
reducedFeatVec = featVec[:axis] # 得到axis列之前列的特征
# 在处理多个列表时,append()是添加的列表元素,而extend()添加的是元素
reducedFeatVec.extend(featVec[axis+1:]) # 得到axis列之后列的特征
retDataSet.append(reducedFeatVec)
return retDataSet # 返回指定划分特征外的特征
# 选择最好的数据集划分方式,根据熵计算
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 留出来一列,把最后一列作为类别标签
baseEntropy = calcShannonEnt(dataSet) # 计算原始香浓熵,划分前
bestInfoGain = 0.0; bestFeature = -1 # 赋初值
for i in range(numFeatures): # 迭代所有的特征
# 使用列表推导来创建新的列表,featlist得到的是每列的特征元素
featList = [example[i] for example in dataSet]
print 'featList:',featList
uniqueVals = set(featList) # 转换为集合,以此确保其中的元素的唯一性
print 'uniqueVals:',uniqueVals
newEntropy = 0.0
for value in uniqueVals:
# 每一列按照不重复的元素划分,返回剩余特征
subDataSet = splitDataSet(dataSet, i, value)
print 'subDataSet:',subDataSet
prob = len(subDataSet)/float(len(dataSet)) # 频率
# 得到此次划分的熵,此处的prob和calcShannonEnt()中的prob不是同一种,第一个
# 是实例数在整体数组的频率,第二个是部分数组中的标签频率,newEntropy求的是信息期望
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy # 计算信息增益,即熵的减少
print 'infogain:',infoGain
if (infoGain > bestInfoGain): #如果信息量减少,就把减少量作为基准
bestInfoGain = infoGain
bestFeature = i
return bestFeature # 返回信息信息增益最高的特征列
# 多数表决法决定叶节点分类
def majorityCnt(classList): # classList是分类名称的列表
classCount={} # 存储每个类标签出现的频率
for vote in classList:
# 统计所有的不重复的key
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
# 对分类进行倒序排列
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),\
reverse=True)
# 返回出现次数最多的名称
return sortedClassCount[0][0]
# 创建树的函数代码
def createTree(dataSet,labels): # labels存储的是特征标签
classList = [example[-1] for example in dataSet] # 数据集的最后一列作为类标签列表
print 'classList:',classList
# 判断类别是否完全相同,通过查看类标签的第一个的数目
if classList.count(classList[0]) == len(classList):
return classList[0]
# 判断是否遍历完所有的特征,通过查看剩下的特征数是不是剩下一个
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) # 找到当前数据集最好的特征的索引
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}} # 嵌套字典,得到一个当前最好的特征标签
del(labels[bestFeat]) # 删除当前最好的特征标签。即划分时类别特征数目减少
# 列表推到式得到数据集的最好特征标签的那列
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues) # 集合确保元素的唯一性
for value in uniqueVals:
subLabels = labels[:] # 得到当前剩余的特征标签
# 递归调用
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,\
bestFeat, value),subLabels)
return myTree
# 使用决策树的分类函数
def classify(inputTree,featLabels,testVec):
firstStr = inputTree.keys()[0] # 树的第一个键
secondDict = inputTree[firstStr] # 第一个键对应的值(字典)
featIndex = featLabels.index(firstStr) # 第一个键(特征)在特征列表中的索引
print 'featIndex:',featIndex
key = testVec[featIndex] # key是相应特征对应测试列表中的的取值,也即是父子节点间的判断
print 'key:',key
valueOfFeat = secondDict[key] #
print 'valueOfFeat:',valueOfFeat
if isinstance(valueOfFeat, dict):
classLabel = classify(valueOfFeat, featLabels, testVec)
else: classLabel = valueOfFeat
return classLabel
# 使用pickle模块储存决策树
def storeTree(inputTree,filename):
with open(filename,'w') as fw:
pickle.dump(inputTree,fw)
def grabTree(filename):
with open(filename,'r') as fr:
return pickle.load(fr)
# 执行分类
mydat,labels=createDataSet() # mydat,lables相当于全局变量
myTree=tree_plot.retrieveTree(0) # 树字典
print classify(myTree,labels,[1,0]) # 输出预测类型
# 预测隐形眼镜类型
with open('lenses.txt','r') as fr:
# '\t'是tab分隔符,得到的是数组[[],[]....]
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels=['age','prescript','astigmatic','tearRate']
lensesTree=createTree(lenses,lensesLabels)
#storeTree(lensesTree,'clf.txt')
print 'load:',grabTree('clf.txt')
tree_plot.createPlot(lensesTree)
运行结果:
no
load: {'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}
由此得到了决策树。。
此处还有一个问题没有解决:就是
myTree=tree_plot.retrieveTree(0) # 树字典
数字典用的是写好的,也可以说是运行得到的树字典,但是间接的。
如果直接用运行得到的字典
# 执行分类'
mydat,labels=createDataSet() # mydat,lables相当于全局变量
#myTree=tree_plot.retrieveTree(0) # 树字典
myTree=createTree(dataSet,labels)
print classify(myTree,labels,[1,0]) # 输出预测类型
报错:
featIndex = featLabel.index(str(firstStr)) # 第一个键(特征)在特征列表中的索引
ValueError: 'no surfacing' is not in list
这个问题还没解决。。