机器学习实战(三)

两个星期看了三章,感觉这本书讲的也就那样,幸亏之前看过机器学习的算法,不然感觉应该看起来会很吃力。
书上都是直接讲这个函数是干嘛的,很多时候连函数参数的具体意思都不解释一下。这样的话,书直接给源码得了,讲那么多没卵用。
这是对前几章的感觉,接着看试试。
tree.py
构建决策树:

# -*- coding: utf-8 -*-
"""
Created on Mon Nov 09 21:14:20 2015

@author: hzh
"""

from math import log
import operator
import copy

def createDataSet():
    """Return the toy "is it a fish?" dataset and its feature names.

    Each sample is [no-surfacing, flippers, class]; the last column is
    the class label ('yes' = fish, 'no' = not fish).
    """
    samples = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    featureNames = ['no surfacing', 'flippers']
    return samples, featureNames

# Compute the Shannon entropy of a dataset's class column.
def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the last column of dataSet."""
    total = len(dataSet)
    # Tally how often each class label appears.
    counts = {}
    for row in dataSet:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    # H = -sum(p * log2(p)) over the label distribution.
    entropy = 0.0
    for count in counts.values():
        p = count / float(total)
        entropy -= p * log(p, 2)
    return entropy

# Partition a dataset on one feature value.
def splitDataSet(dataSet, axis, value):
    """Return the rows whose feature at index `axis` equals `value`,
    with that feature column removed from each returned row."""
    result = []
    for row in dataSet:
        if row[axis] != value:
            continue
        # Rebuild the row without column `axis` (rows are never mutated).
        result.append(row[:axis] + row[axis + 1:])
    return result

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split yields the largest
    information gain, or -1 when no split beats zero gain (ID3 criterion)."""
    featureCount = len(dataSet[0]) - 1          # last column is the class
    baseEntropy = calcShannonEnt(dataSet)
    bestGain = 0.0
    bestFeature = -1
    for featIdx in range(featureCount):
        distinctValues = {row[featIdx] for row in dataSet}
        # Entropy after splitting on this feature, weighted by subset size.
        splitEntropy = 0.0
        for val in distinctValues:
            subset = splitDataSet(dataSet, featIdx, val)
            weight = len(subset) / float(len(dataSet))
            splitEntropy += weight * calcShannonEnt(subset)
        gain = baseEntropy - splitEntropy
        if gain > bestGain:                     # strict: ties keep earlier feature
            bestGain = gain
            bestFeature = featIdx
    return bestFeature

def majorityCnt(classList):
    """Return the most common class label in classList (majority vote).

    Fix: the original called dict.iteritems(), which only exists in
    Python 2; dict.items() works on both Python 2 and 3.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # Sort (label, count) pairs by count, most frequent first.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    The tree shape is {featureName: {featureValue: subtree-or-classLabel}}.

    Fix: the original did del(labels[bestFeat]), mutating the caller's
    label list in place (which is why the demo script had to deepcopy
    `labels` before calling).  We now build a reduced copy instead; the
    returned tree is identical and callers' lists are left untouched.
    """
    classList = [row[-1] for row in dataSet]
    # Base case 1: every sample has the same class -> leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Base case 2: no features left (only the class column) -> majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Reduced label list for the subtrees; does NOT mutate the argument.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    featValues = set(row[bestFeat] for row in dataSet)
    for value in featValues:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels[:])
    return myTree

# Classify one sample with a trained decision tree.
def classify(inputTree, featLabels, testVec):
    """Walk inputTree and return the class label for testVec.

    featLabels gives the feature name for each index of testVec, so the
    tree's node names can be mapped back to positions in the sample.

    Fixes: inputTree.keys()[0] fails under Python 3 (dict views are not
    indexable); and a feature value absent from the tree used to crash
    with an opaque UnboundLocalError — it now raises a clear KeyError.
    """
    firstStr = next(iter(inputTree))        # root node's feature name
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # slot of testVec this node tests
    value = testVec[featIndex]
    if value not in secondDict:
        raise KeyError('unseen value %r for feature %r' % (value, firstStr))
    subtree = secondDict[value]
    if isinstance(subtree, dict):           # internal node: keep descending
        return classify(subtree, featLabels, testVec)
    return subtree                          # leaf: the class label

# Persist a decision tree to disk with pickle.
def storeTree(inputTree, filename):
    """Serialize inputTree to `filename`.

    Fix: pickle writes bytes, so the file must be opened in binary mode
    ('wb') — text mode breaks under Python 3 (and on Windows under
    Python 2); a context manager guarantees the handle is closed.
    """
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

def grabTree(filename):
    """Load a pickled decision tree from `filename`.

    Fix: open in binary mode ('rb') to match pickle's byte stream
    (required under Python 3), and close the handle via a context
    manager — the original leaked it.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

# --- demo: build a tree from the toy dataset and classify one sample ---
myData, labels = createDataSet()
# createTree (as published) mutates `labels`, so keep an untouched copy
# for classify(), which needs the full feature-name list.
mm = copy.deepcopy(labels)
myTree = createTree(myData, labels)
# Fix: parenthesized print calls work under both Python 2 and Python 3;
# the bare `print x` statement is a SyntaxError in Python 3.
print(mm)
print(classify(myTree, mm, [0, 0]))

treePlotter.py
主要是将决策树显示出来的一些函数:

# -*- coding: utf-8 -*-
"""
Created on Thu Nov 12 12:31:07 2015

@author: hzh
"""

import matplotlib.pyplot as plt
import tree

# Matplotlib annotation styles: decision nodes, leaf nodes, and the
# arrow drawn from each node back to its parent.
decisionNode = {'boxstyle': 'sawtooth', 'fc': '0.8'}
leafNode = {'boxstyle': 'round4', 'fc': '0.8'}
arrow_args = {'arrowstyle': '<-'}

# Draw one annotated node (styled box + arrow from its parent) on the
# shared axes.  createPlot.ax1 is a function attribute set by createPlot(),
# so createPlot() must run before this is called.
def plotNode( nodeTxt, centerPt, parentPt, nodeType ):
    # xy is the parent position and xytext the node position, so the '<-'
    # arrow points from the node back toward its parent.  Coordinates are
    # in 'axes fraction' units (0..1 across the axes).
    createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction', 
                            xytext = centerPt, textcoords = 'axes fraction', 
                            va = 'center', ha = 'center', bbox = nodeType, arrowprops = arrow_args)

# Helpers to measure a tree stored as nested dicts: leaf count and depth.
def getNumLeafs(myTree):
    """Return the number of leaf nodes in myTree.

    Fix: the original used myTree.keys()[0], which fails under Python 3
    where dict.keys() returns a non-indexable view object.
    """
    numLeafs = 0
    firstStr = next(iter(myTree))              # root feature name
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):  # internal node: recurse
            numLeafs += getNumLeafs(secondDict[key])
        else:                                  # leaf: a class label
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    """Return the depth of myTree (number of decision levels to a leaf).

    Fixes two defects: (1) the original recursed into getNumLeafs()
    instead of getTreeDepth(), so the depth of any tree deeper than two
    levels was wrong (it added a subtree's LEAF COUNT, not its depth);
    (2) myTree.keys()[0] fails under Python 3.
    """
    maxDepth = 0
    firstStr = next(iter(myTree))              # root feature name
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):  # internal node: recurse (fixed)
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:                                  # leaf contributes depth 1
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# --- plotTree helper functions ---
# Write txtString at the midpoint of the edge between cntrPt and parentPt
# (used for the feature-value labels on tree branches).
def plotMidText(cntrPt, parentPt, txtString):
    midX = cntrPt[0] + (parentPt[0] - cntrPt[0]) / 2.0
    midY = cntrPt[1] + (parentPt[1] - cntrPt[1]) / 2.0
    createPlot.ax1.text(midX, midY, txtString)

# Recursively draw the tree.  Layout state is carried in function attributes
# (plotTree.xOff/yOff/totalW/totalD), all initialized by createPlot(), so
# this must only be called from createPlot().
def plotTree( myTree, parentPt, nodeTxt ):
    numLeafs = getNumLeafs( myTree )
    # NOTE(review): `depth` is computed but never used in this function.
    depth = getTreeDepth( myTree )
    # NOTE(review): dict.keys()[0] is Python-2-only; fails under Python 3.
    firstStr = myTree.keys()[0]
    # Center this subtree's root over its leaves: advance half the subtree's
    # leaf span (in units of 1/totalW) from the current x cursor.
    cntrPt = ( plotTree.xOff + ( 1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
               plotTree.yOff )
    plotMidText( cntrPt, parentPt, nodeTxt )       # branch label from parent
    plotNode( firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Descend one level before drawing children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type( secondDict[key] ).__name__ == 'dict' :
            # Internal node: recurse; it will position itself.
            plotTree( secondDict[key], cntrPt, str(key) )
        else:
            # Leaf: advance the x cursor one slot and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode( secondDict[key], ( plotTree.xOff, plotTree.yOff ), cntrPt, leafNode )
            plotMidText( ( plotTree.xOff, plotTree.yOff ), cntrPt, str(key) )
    # Restore the y cursor for the caller's level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    """Draw the decision tree `inTree` (nested dicts) in a new figure."""
    fig = plt.figure( 1, facecolor = 'white' )
    fig.clf()
    # No ticks: the tree is drawn in bare axes-fraction coordinates.
    axprops = dict( xticks = [], yticks = [] )    
    # Shared axes used by plotNode/plotMidText via this function attribute.
    createPlot.ax1 = plt.subplot( 111, frameon = False, **axprops )
    # Global layout state for plotTree: total width in leaves, total depth,
    # and the x/y cursors (x starts half a slot left of 0 so the first leaf
    # lands at the center of its slot).
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree( inTree, ( 0.5, 1.0 ), ' ')    
    plt.show()

# --- demo: plot the toy tree, then the contact-lens tree ---
# Fix: the two lines below were commented out in the original, leaving
# `myTree` undefined in this module (NameError at createPlot(myTree)).
myData, labels = tree.createDataSet()
myTree = tree.createTree( myData, labels )
createPlot(myTree)
# Contact-lens example from the book; requires lenses.txt (tab-separated)
# next to this script.  `with` closes the handle the original leaked.
with open('lenses.txt') as fr:
    lenses = [inst.strip().split('\t') for inst in fr]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = tree.createTree(lenses, lensesLabels)
createPlot(lensesTree)

# NOTE(review): 'x.txt' is never written anywhere in this post — presumably
# produced earlier via tree.storeTree; confirm it exists before running.
lensesTree = tree.grabTree('x.txt')
createPlot(lensesTree)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值