两个星期看了三章，感觉这本书讲得也就那样。幸亏之前看过机器学习的算法，不然看起来应该会很吃力。
书上大多直接说某个函数是干什么的，很多时候连函数参数的具体含义都不解释。这样的话，书里直接给源码就行了，讲那么多也没什么用。
这是对前几章的感觉，接着往下看看再说。
tree.py
构建决策树:
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 09 21:14:20 2015
@author: hzh
"""
from math import log
import operator
import copy
def createDataSet():
    """Return the toy dataset and its feature names.

    Returns a tuple (rows, names):
      rows  -- five rows, each [no surfacing, flippers, class], class in the last column
      names -- the feature names for the first two columns
    Fresh lists are built on every call, so callers may mutate them freely.
    """
    names = ['no surfacing', 'flippers']
    rows = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    return rows, names
#计算给定数据集的香农熵
def calcShannonEnt(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in dataSet.

    The class label of each row is taken to be its last element.
    An empty dataSet yields 0.0.
    """
    total = len(dataSet)
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1
    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy
#按照给定的特征划分数据集
def splitDataSet(dataSet, axis, value):
    """Return the rows whose column `axis` equals `value`, with that column removed.

    dataSet -- list of list rows
    axis    -- index of the column to test and drop
    value   -- value that column must equal for a row to be kept
    The input rows are not modified; each kept row is rebuilt as a new list.
    """
    def _drop_column(row):
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        return trimmed

    return [_drop_column(row) for row in dataSet if row[axis] == value]
#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    Every column except the last (the class label) is a candidate feature.
    Gain = entropy(dataSet) - sum over values v of P(v) * entropy(rows with v).
    Only strictly positive gains are considered; ties keep the lowest index.
    Returns -1 when no feature has a positive gain.
    """
    featureCount = len(dataSet[0]) - 1
    parentEntropy = calcShannonEnt(dataSet)
    total = float(len(dataSet))
    bestIndex, bestGain = -1, 0.0
    for index in range(featureCount):
        column = [row[index] for row in dataSet]
        conditional = 0.0
        for v in set(column):
            subset = splitDataSet(dataSet, index, v)
            conditional += len(subset) / total * calcShannonEnt(subset)
        gain = parentEntropy - conditional
        if gain > bestGain:
            bestIndex, bestGain = index, gain
    return bestIndex
def majorityCnt( classList ):
    """Return the most frequent class label in classList.

    Used when no features are left to split on. Ties go to the label whose first
    occurrence comes earliest (insertion order, Python 3.7+). Raises IndexError
    for an empty list, as before.
    """
    # dict.iteritems() exists only on Python 2; Counter works on 2.7 and 3.x and
    # most_common(1) resolves ties the same way the stable sort did.
    from collections import Counter
    return Counter(classList).most_common(1)[0][0]
def createTree( dataSet, labels ):
    """Recursively build an ID3 decision tree.

    dataSet -- list of rows; each row holds the feature values followed by the
               class label as its last element
    labels  -- feature names, one per feature column of dataSet

    Returns a class label when no further split is needed, otherwise a nested dict
    {feature_name: {feature_value: subtree_or_class_label}}.
    The caller's `labels` list is left unmodified.
    """
    # Work on a private copy: deleting the chosen label from the caller's list
    # forced callers to deep-copy their labels before calling classify().
    labels = list(labels)
    classList = [example[-1] for example in dataSet]
    # Every row has the same class: this node is a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left: fall back to a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # chooseBestFeatureToSplit returns -1 when no feature has positive gain
    # (e.g. XOR-like data). Used as an index, -1 would pick the class column and
    # the last feature name, so split on the first feature instead. Each split
    # removes a column, so the recursion still terminates.
    if bestFeat < 0:
        bestFeat = 0
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        # The recursive call copies `labels` itself, so siblings cannot interfere.
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), labels)
    return myTree
#使用决策树的分类函数
def classify( inputTree, featLabels, testVec ):
    """Return the class label that inputTree assigns to testVec.

    inputTree  -- nested dict {feature_name: {feature_value: subtree_or_label}}
    featLabels -- feature names in the same order as the values in testVec
    testVec    -- feature values of the sample to classify

    Raises KeyError if testVec holds a feature value that has no branch in the tree
    (previously this surfaced as an UnboundLocalError).
    """
    # Only the first key of each node dict is read (dict.keys()[0] fails on Python 3).
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    if featValue not in secondDict:
        raise KeyError('no branch for %s = %r in the decision tree' % (firstStr, featValue))
    branch = secondDict[featValue]
    # A dict is a further decision node; anything else is a class label.
    if isinstance(branch, dict):
        return classify(branch, featLabels, testVec)
    return branch
#使用pickle模块存储决策树
def storeTree( inputTree, filename ):
    """Serialize inputTree to the file `filename` using pickle (overwrites the file)."""
    import pickle
    # Pickle data is bytes: binary mode is required on Python 3 and avoids newline
    # translation on Windows. `with` closes the file even if dump() raises.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree( filename ):
    """Load and return a pickled tree from `filename`.

    Security: unpickling can execute arbitrary code, so only load files you
    created yourself (e.g. via storeTree) or otherwise trust.
    """
    import pickle
    # Binary mode matches the pickle byte stream; `with` closes the handle,
    # which was previously leaked.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
if __name__ == '__main__':
    # Demo on the toy dataset: build a tree, then classify one sample.
    # Guarded so that importing this module (e.g. from treePlotter) prints nothing.
    myData, labels = createDataSet()
    # Keep an untouched copy of the feature names for classify(); the list handed
    # to createTree may be modified by it.
    mm = copy.deepcopy(labels)
    myTree = createTree(myData, labels)
    # print(x) with a single argument behaves the same on Python 2 and 3.
    print(mm)
    print(classify(myTree, mm, [0, 0]))
treePlotter.py
主要是用于把决策树绘制出来的一些函数：
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 12 12:31:07 2015
@author: hzh
"""
import matplotlib.pyplot as plt
import tree
# Box style for internal (decision) nodes: sawtooth outline, light-grey fill.
decisionNode = {'boxstyle': 'sawtooth', 'fc': '0.8'}
# Box style for leaf nodes: rounded outline, same fill.
leafNode = {'boxstyle': 'round4', 'fc': '0.8'}
# Arrow style for the edges drawn by plotNode.
arrow_args = {'arrowstyle': '<-'}
def plotNode( nodeTxt, centerPt, parentPt, nodeType ):
    """Draw nodeTxt in a box styled by nodeType at centerPt, with an arrow to parentPt.

    Both points are in axes-fraction coordinates. The arrow uses the module-level
    arrow_args style. Requires createPlot.ax1 to have been set by createPlot.
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va='center',
        ha='center',
        bbox=nodeType,
        arrowprops=arrow_args,
    )
#获取叶节点的数目和树的层数
def getNumLeafs( myTree ):
    """Return the number of leaf nodes in a decision tree.

    myTree is a nested dict {feature_name: {feature_value: subtree_or_label}}.
    Every branch value that is not itself a dict counts as one leaf.
    """
    numLeafs = 0
    # Only the first key of the node dict is read (dict.keys()[0] fails on Python 3).
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth( myTree ):
    """Return the number of decision levels in a tree.

    myTree is a nested dict {feature_name: {feature_value: subtree_or_label}}.
    A node whose branches are all leaves has depth 1.
    """
    maxDepth = 0
    # Only the first key of the node dict is read (dict.keys()[0] fails on Python 3).
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        if isinstance(child, dict):
            # Recurse on depth; this previously called getNumLeafs by mistake,
            # which returned 1 + leaf count instead of the depth.
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
#plotTree函数
def plotMidText( cntrPt, parentPt, txtString ):
    """Write txtString at the midpoint between cntrPt and parentPt on createPlot.ax1."""
    parentX, parentY = parentPt[0], parentPt[1]
    childX, childY = cntrPt[0], cntrPt[1]
    midX = (parentX - childX) / 2.0 + childX
    midY = (parentY - childY) / 2.0 + childY
    createPlot.ax1.text(midX, midY, txtString)
def plotTree( myTree, parentPt, nodeTxt ):
    """Recursively draw myTree below parentPt, labelling the incoming edge with nodeTxt.

    Layout state lives in function attributes initialised by createPlot:
    plotTree.totalW (leaf count of the whole tree), plotTree.totalD (its depth),
    and the running cursor plotTree.xOff / plotTree.yOff (axes fractions).
    """
    numLeafs = getNumLeafs(myTree)
    # Only the first key of the node dict is read (dict.keys()[0] fails on Python 3).
    firstStr = next(iter(myTree))
    # Centre this decision node horizontally over the slots its leaves will occupy.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step down one level for the children...
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        child = secondDict[key]
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf advances the x cursor by one leaf slot.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # ...and step back up once all children are drawn.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Render inTree into matplotlib figure 1 and display it.

    Clears figure 1 and creates a frameless, tick-free axes (stored as
    createPlot.ax1 for plotNode/plotMidText). It then seeds plotTree's layout
    attributes (totalW, totalD, xOff, yOff) before drawing from the top centre.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    noTicks = {'xticks': [], 'yticks': []}
    createPlot.ax1 = plt.subplot(111, frameon=False, **noTicks)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of 0 so the first leaf lands centred in its slot.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    rootPt = (0.5, 1.0)
    plotTree(inTree, rootPt, ' ')
    plt.show()
if __name__ == '__main__':
    # Plot the toy tree from tree.py. myTree was previously referenced without
    # being defined here (its definition was commented out), raising NameError.
    myData, labels = tree.createDataSet()
    myTree = tree.createTree(myData, labels)
    createPlot(myTree)

    # Build and plot a tree from lenses.txt: tab-separated rows, class in the last column.
    with open('lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = tree.createTree(lenses, lensesLabels)
    createPlot(lensesTree)

    # NOTE(review): x.txt must already hold a pickled tree (e.g. written by
    # tree.storeTree); nothing in this script creates it.
    lensesTree = tree.grabTree('x.txt')
    createPlot(lensesTree)