机器学习(二):决策树算法

博客参考书籍:《Machine Learing in Action》 Peter Harrington
使用语言:Python
Python需要装载的库有:numpy ,matplotlib。安装参考博客:

http://blog.csdn.net/zhangyuehuan/article/details/39134747
Python相关函数:
extend()。在原矩阵的基础上进行扩展。比如[2,1,1].extend([1,1])=[2,1,1,1,1].

决策树算法:它是一种典型的分类算法,将样本数据按照分类因素构造决策树,当对新数据进行判断时,将其按照决策树,逐渐选择分支,最终确认新数据的分类。比如,将生物进行分类:先按照是否是动物分为动物类及植物类,然后对动物类按照生活环境分为陆生、水生、两栖类,以此类推。

优点:计算简单,容易被人理解,允许丢失数据,可以处理不相干数据。
缺点:容易过度拟合。

决策树构建算法:ID3算法。

算法介绍:
算法流程图
决策树构建采用递归算法,对样本数据按照最优分类因素进行分类,将其分为决策树的若干个分支,然后对分支继续调用决策树算法,最终将其分割成不可分类的分支。不可分类分支的判别标准为:1)分支数据属于同一分类。2)虽然分支数据不属于同一分类,但已经没有分类因素了。

使用决策树对数据进行判定:将数据依次与决策树的分支因素比较,将其分类,最终到达某一分支,从而判断其类型。

相关模块:在创建决策树中,有个核心的模块–找到最优分类因素,并按照分类因素,将其分为若干子树。那么如何判断最优分类因素呢?我们可以采用香农熵的概念,熵值越大,则表示该信源的信息分布越均匀,而按照最优分类因素将数据分类后,可以将数据的信息熵变得最小,即分布集中化。信源熵的计算为:
H(x) = E[I(xi)] = E[ log(2,1/p(xi)) ] = -∑p(xi)log(2,p(xi)) (i=1,2,..n)
计算方法举例如下:原始数据数量为(N+M)按照因素A分为俩部分,B,C。B的数量为N,C的数量为M,则此时对应的信息熵为:
N*H(B)/(M+N)+M*H(C)/(M+N)
所以在此需要三个模块:1、香农熵计算。2、将数据按照某一因素对应的值分出对应的模块。3、根据香农熵,找出最优分类因素。

#coding=utf-8

from math import log
import operator

'''
Function:   计算数据的香农熵
Input:  样本表,样本末为归类
'''
def calcShannonEntropy(dataSet):
    dataNum = len(dataSet)
    labelCount = {}
    shannonEntropy = 0.0
    for data in dataSet:#统计每个类别的数量
        dataLabel = data[-1]
        if dataLabel not in labelCount.keys():
            labelCount[dataLabel] = 0#如果是第一次统计
        labelCount[dataLabel] += 1
    for key in labelCount:
        pro = float(labelCount[key]) / dataNum
        shannonEntropy -= pro * log(pro, 2)
    return shannonEntropy

'''
Function:   按照分类因素,返回分割好的矩阵
Input: dataSet,样本表  axis分类因素所在的位置,value分类因素对应的值
Output:分类因素为value的矩阵,但去掉了该分类因素
''' 
def spiltDataSet(dataSet, axis, value):
    returnMat = []
    for data in dataSet:
        if data[axis] == value:#如果是该类的矩阵
            newData = data[:axis]
            newData.extend(data[axis+1:])
            returnMat.append(newData)
    return returnMat

'''
Function:   找到最合适的分类因素
Input:      dataSet需要分类的矩阵, labels对应的分类因素
Output: 分类因素对应的标号
'''
def foundBestFactor(dataSet, labels):
    axis = -1
    dataNum = len(dataSet)
    factorNum = len(labels)
    minShannonEntropy = calcShannonEntropy(dataSet)
    tempShannonEntropy = 0.0
    for i in range(factorNum):
        tempValues = [data[i] for data in dataSet]#获取对应因素的值矩阵
        values = set(tempValues)
        for value in values:
            tempMat = spiltDataSet(dataSet, i, value)
            pro = float(len(tempMat)) / dataNum
            tempShannonEntropy += pro * calcShannonEntropy(tempMat)
        if(tempShannonEntropy < minShannonEntropy):
            minShannonEntropy = tempShannonEntropy
            axis = i
        tempShannonEntropy = 0.0
    return axis


def createDataSet():
    dataSet = [[1, 1, 'yes'],
    [1, 1, 'yes'],
    [1, 0, 'no'],
    [0, 1, 'no'],
    [0, 1, 'no']]
    labels = ['no surfacing','flippers']
    return dataSet, labels

至此,模块构建完成,可以按照算法流程图,构造决策树了,对于每个决策树的叶子节点,都应该有一个具体的分类,但有时存在一个叶子节点有多个类别,则我们从中选取一个占据最多的类别。
下面则为确定叶子种类的算法:

'''
Function:   找到最合适的所属类别
Input:      分类列表
Output: 叶子节点对应的分类
'''     
def selectBestLabel(classList):
    classCount = {}
    for label in classList:
        if label not in classCount:
            classCount[label] = 0
        classCount[label] += 1
    sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]
'''
Function:   创建决策树,叶子节点为对应的class类别
Input:      dataSet样本列表, labels判断因素
Output: 决策树
'''     
def createTree(dataSet, labels):
    classList = [data[-1] for data in dataSet]
    if classList.count(classList[0]) == len(classList):#如果所有样本都属于同一类
        return classList[0]
    if len(dataSet[0]) == 1:
        return selectBestLabel(dataSet)#如果已经无分类因素了,则找到最多的类别
    bestFac = foundBestFactor(dataSet, labels)
    bestFacLabel = labels[bestFac]#最好的分类因素对应的标签
    del(labels[bestFac])#删除掉该因素
    myTree = {bestFacLabel:{}}
    tempValues = [data[bestFac] for data in dataSet]
    values = set(tempValues)
    for value in values:
        mLabels = labels[:]#保留住当前label,要多次使用
        myTree[bestFacLabel][value] = createTree(spiltDataSet(dataSet, bestFac, value), mLabels)
    return myTree

至此,决策树构建完毕,但是如果每次都需要进行构建决策树操作,会耗费大量资源,因此,需要以一定的形式保存决策树,python带有的pickle可以完美的解决。

'''
Function:   保存决策树
Input:      tree决策树
Output: 文件
''' 
def saveTree(tree, filename):
    import pickle 
    file = open(filename,'w')
    pickle.dump(tree, file)
    file.close()

'''
Function:   加载决策树
Input:      文件名
Output: 决策树
''' 
def loadTree(filename):
    import pickle 
    file = open(filename)
    return pickle.load(file)

构建决策树完成,现在可以使用决策树去判定数据了。

'''
Function:   使用决策树分类
Input:      决策树,决策因素,测试数据
Output: 分类
''' 
def classify(inputTree, facLabels, testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    facIndex = facLabels.index(firstStr)
    key = testVec[facIndex]
    valueOfFac = secondDict[key]
    if isinstance(valueOfFac, dict): 
        classLabel = classify(valueOfFac, facLabels, testVec)
    else: classLabel = valueOfFac
    return classLabel

至此,所有的决策树算法完成。
为了更好的看出决策树,可以使用matplotpy库来画出来。代码(来自参考书籍)如下:

'''
Created on Oct 14, 2010

@author: Peter Harrington
'''
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes   
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

#def createPlot():
#    fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()

def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

#createPlot(thisTree)

以上,则是ID3算法构建决策树的所有内容。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值