Machine Learning in Action, Classification Part 1: Decision Trees (the if-else from bank lending to blind dates)

Decision Tree:

A decision tree is a tree-structured model that describes how instances are classified. It consists of nodes and directed edges. Nodes come in two types: internal nodes and leaf nodes. An internal node represents a feature or attribute, and a leaf node represents a class.
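In the code later in this post, such a tree is stored as nested Python dictionaries; a minimal sketch (the feature 'house' and the labels here are just an illustration):

toyTree = {'house': {0: 'no', 1: 'yes'}}   #internal node: the feature 'house';
                                           #edges: its values 0/1; leaves: class labels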

Building the Decision Tree:

This process can usually be summarized in three steps: feature selection, decision tree generation, and decision tree pruning.

Feature selection means choosing features that actually have classification power on the training data, which improves the efficiency of decision tree learning. If splitting on a feature gives results that differ little from splitting at random, the feature has no classification power, and experience shows that discarding such a feature has little effect on the tree's accuracy. The usual criteria for feature selection are information gain or the information gain ratio.

Intuitively, a feature has stronger classification power if partitioning the training set by it produces subsets that are as well classified as possible under the current conditions; such a feature should be preferred. Information gain expresses this intuition directly: it is the change in information after the dataset is partitioned. Once we know how to compute it, we can measure the information gain obtained by splitting on each feature, and the feature with the highest gain is the best choice.

The measure of information in a set is called Shannon entropy, or simply entropy.

The Shannon entropy in mathematical form: H = -\sum_{i=1}^{n} p(x_i) \log_2 p(x_i)
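As a quick sanity check of the formula (a minimal sketch; the 9-vs-6 split anticipates the loan table used later in this post):

from math import log

probs = [9/15, 6/15]                    #P(yes), P(no) over the class labels
H = -sum(p * log(p, 2) for p in probs)  #Shannon entropy
print(round(H, 3))                      #0.971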

The conditional entropy H(Y|X) measures the uncertainty remaining in a random variable Y once the random variable X is known. It is defined as the expectation, over X, of the entropy of the conditional distribution of Y given X:

                          H(Y|X) = \sum_{i=1}^{n}p_iH(Y|X=x_i) \ \ \ \ where \ \ p_i = P(X=x_i)
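A small sketch of this definition in code (the numbers again anticipate the house feature of the loan table below: 6 of the 15 samples have a house and are all approved, while 3 of the remaining 9 are approved):

from math import log

def entropy(probs):
    return -sum(p * log(p, 2) for p in probs if p > 0)

p_x = [9/15, 6/15]                      #P(X = x_i): no house / house
h_y_given_xi = [entropy([3/9, 6/9]),    #H(Y|X = no house)
                entropy([6/6])]         #H(Y|X = house) = 0, all approved
h_y_given_x = sum(p * h for p, h in zip(p_x, h_y_given_xi))
print(round(h_y_given_x, 3))            #0.551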

For data with multiple features, the Shannon entropy and the conditional entropy must be computed separately for each feature, since information gain is defined relative to a feature. The information gain G(D,A) of feature A with respect to training set D is defined as the difference between the empirical entropy H(D) of set D and the empirical conditional entropy H(D|A) of D given feature A:

                                                           G(D,A) = H(D) - H(D|A)  

Take loan approval as an example, with i = 0, 1 (0 for reject, 1 for approve). The Shannon entropy of the loan data below is H(loan) = -\sum_{i=0}^{1} p_i \log_2 p_i = 0.971. For each of the other features, the conditional entropy and information gain are then:

                                        G(D,A_i) = H(D) - H(D|A_i) = H(D) - \sum_{j=1}^{n} p_j H(D_j) \ \ \ \ where \ \ p_j = |D_j| / |D|
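As a concrete check, take A = house in the loan table below: 6 of the 15 applicants own a house and are all approved (so that subset has zero entropy), while 3 of the remaining 9 are approved, giving

                                        G(D, house) = 0.971 - (6/15 \cdot 0 + 9/15 \cdot 0.918) = 0.971 - 0.551 = 0.420

This is the largest gain of the four features (age: 0.083, job: 0.324, credit: 0.363), which is why the generated tree splits on house first; the maxEntropy function below prints exactly these values.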

Generating the Decision Tree:

When generating the tree recursively, there are two stopping conditions. First, if all class labels in the current subset are identical, return that label directly. Second, if all features have been used and the data still cannot be split into groups containing a single class, tree construction has run out of features, which means the data does not have enough dimensions. Since the second condition leaves no single class label to return, the class that occurs most often is returned instead.

The Blind-Date Problem:

The hypothetical blind-date classification system shown in the figure above works as follows. It first checks whether the candidate owns a house. If so, the candidate is worth getting to know further. If not, it checks whether the candidate is ambitious: if not, say goodbye right away, with something like "You're a nice person, but we're not a match." If so, put the candidate on the shortlist; "shortlist" is the polite term, and the less flattering one is "backup".
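In the nested-dictionary form used by the code below, a sketch of this hypothetical tree (the feature names and leaf labels are just an illustration) would be:

dateTree = {'house': {1: 'worth further contact',
                      0: {'ambitious': {0: 'say goodbye',
                                        1: 'shortlist (backup)'}}}}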

The code below uses the decision tree algorithm to build and plot such a tree. (The dataset is the bank-loan table discussed above; a blind-date tree is built the same way.)

import matplotlib.pyplot as plt
from math import log
import operator
'''
Function : createDataSet()
    Description : to create the bank-loan dataset
    Args : None
    Rets : featureMatrix, labels
'''

def createDataSet():
    #featureMatrix columns: age, job, house, loanInfo (credit); last column is the class label
    featureMatrix = [
        [0, 0, 0, 0, 'no'],
        [0, 0, 0, 1, 'no'],
        [0, 1, 0, 1, 'yes'],
        [0, 1, 1, 0, 'yes'],
        [0, 0, 0, 0, 'no'],
        [1, 0, 0, 0, 'no'],
        [1, 0, 0, 1, 'no'],
        [1, 1, 1, 1, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [1, 0, 1, 2, 'yes'],
        [2, 0, 1, 2, 'yes'],
        [2, 0, 1, 1, 'yes'],
        [2, 1, 0, 1, 'yes'],
        [2, 1, 0, 2, 'yes'],
        [2, 0, 0, 0, 'no']
    ]
    labels = ['age', 'job', 'house', 'loanInfo']
    return featureMatrix, labels

'''
Function : shannonEntropy(featureMatrix)
    Description : to calculate the Shannon entropy of the class labels in featureMatrix
    Args : featureMatrix
    Rets : H #shannon entropy
'''
def shannonEntropy(featureMatrix):
    rows = len(featureMatrix)
    labels = {}
    #frequency statistics
    for row in featureMatrix:
        label = row[-1]
        if label not in labels.keys():
            labels[label] = 0
        labels[label] += 1
    H = 0.0
    for i in labels:
        p_xi = float(labels[i]) / rows
        H -= p_xi * log(p_xi, 2)
    return H
'''
Function : splitDataSet(featureMatrix, axis, value)
    Description : to split the dataset, keeping rows where row[axis] == value and dropping that column
    Args :  featureMatrix
            axis
            value
    Rets : subDataSet
'''
def splitDataSet(featureMatrix, axis, value):
    subDataSet = []
    for row in featureMatrix:
        #get the feature set except feature itself
        if row[axis] == value:
            reducedFeature = row[:axis]
            reducedFeature.extend(row[axis+1:])
            subDataSet.append(reducedFeature)
    return subDataSet
'''
Function : maxEntropy(featureMatrix)
    Description : to choose the feature with the maximum information gain
    Args : featureMatrix
    Rets : bestFeature
'''
def maxEntropy(featureMatrix):
    #get numFeature
    numFeature = len(featureMatrix[0]) - 1
    baseEntropy = shannonEntropy(featureMatrix)
    maxInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeature):
        featureList = [row[i] for row in featureMatrix]
        uniqueValues = set(featureList)
        newEntropy = 0.0
        #calculate every value's entropy
        for value in uniqueValues:
            subDataSet = splitDataSet(featureMatrix, i, value)
            p_i = len(subDataSet) / float(len(featureMatrix))
            newEntropy += p_i * shannonEntropy(subDataSet)
        infoGain = baseEntropy - newEntropy
        print('%dth feature info gain is : %.3f'%(i, infoGain))
        #update maxInfoGain and bestFeature
        if(infoGain > maxInfoGain):
            maxInfoGain = infoGain
            bestFeature = i
    return bestFeature

'''
Function : majorityCount(classList)
    Description : to count majority class
    Args : classList
    Rets : sortedClassList[0][0]
'''
def majorityCount(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    #get the max frequency bestFeature and value
    sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]

'''
Function : createTree(featureMatrix, labels, featureLabels)
    Description : to create the tree structure from featureMatrix
    Args :  featureMatrix
            labels
            featureLabels #filled with the chosen feature names, in order
    Rets :  dTree
'''
def createTree(featureMatrix, labels, featureLabels):
    #get loanInfo
    classList = [row[-1] for row in featureMatrix]
    #if all class labels are identical, return that label
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    #if features are exhausted, return the majority class
    if len(featureMatrix[0]) == 1 or len(labels) == 0:
        return majorityCount(classList)
    bestFeature = maxEntropy(featureMatrix)
    bestFeatureLabel = labels[bestFeature]
    featureLabels.append(bestFeatureLabel)
    dTree = {bestFeatureLabel:{}}
    del(labels[bestFeature])
    #get the bestFeature column values from featureMatrix
    featureValue = [row[bestFeature] for row in featureMatrix]
    #ignore the repeat value
    uniqueValues = set(featureValue)
    #traverse the feature value to create the tree
    for value in uniqueValues:
        #pass a copy of labels so sibling branches don't share deletions
        subLabels = labels[:]
        dTree[bestFeatureLabel][value] = createTree(splitDataSet(featureMatrix, bestFeature, value), subLabels, featureLabels)
    return dTree
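
'''
Function : classify(dTree, featureLabels, testVec)
    Description : (added sketch, not from the original post) walk the built
                  tree to predict the class of one sample; testVec lists the
                  sample's feature values in the same order as featureLabels
    Args :  dTree
            featureLabels
            testVec
    Rets : classLabel
'''
def classify(dTree, featureLabels, testVec):
    firstStr = next(iter(dTree))
    secondDict = dTree[firstStr]
    #find which position in testVec holds this node's feature
    featureIndex = featureLabels.index(firstStr)
    classLabel = None
    for key in secondDict.keys():
        if testVec[featureIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featureLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel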

'''
Function : getNumLeaf(dTree)
    Description : to get the number of leaf nodes in dTree
    Args : dTree
    Rets : numLeafs
'''
def getNumLeaf(dTree):
    numLeafs = 0
    firstStr = next(iter(dTree))
    secondDict = dTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeaf(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

'''
Function : getTreeDepth(dTree)
    Description : to get the depth of dTree
    Args : dTree
    Rets : depth
'''
def getTreeDepth(dTree):
    depth = 0
    firstStr = next(iter(dTree))
    secondDict = dTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > depth:
            depth = thisDepth
    return depth

'''
Function : plotNode(nodeText, textPosition, arrowPosition, nodeType)
    Description : to draw a node and the arrow pointing to it
    Args : nodeText
            textPosition
            arrowPosition
            nodeType
    Rets : None
'''
def plotNode(nodeText, textPosition, arrowPosition, nodeType):
    arrow_args = dict(arrowstyle = '<-')
    createPlot.ax1.annotate(nodeText, xy = arrowPosition, xycoords = 'axes fraction',
                             xytext = textPosition, textcoords = 'axes fraction',
                              va = 'center', ha = 'center',bbox = nodeType, arrowprops = arrow_args)
'''
Function : plotMidText(startPosition, endPosition, text)
    Description : to draw the edge label midway between startPosition and endPosition
    Args : startPosition
            endPosition
            text
    Rets : None
'''
def plotMidText(startPosition, endPosition, text):
    xmid = (endPosition[0] - startPosition[0]) / 2.0 + startPosition[0]
    ymid = (endPosition[1] - startPosition[1]) / 2.0 + startPosition[1]
    createPlot.ax1.text(xmid, ymid, text, va = 'center', ha = 'center', rotation = 45)

'''
Function : plotTree(dTree, arrowPosition, text)
    Description : to plot the dTree
    Args : dTree
            arrowPosition
            text
    Rets : None
'''
def plotTree(dTree, arrowPosition, text):
    dNode = dict(boxstyle = 'sawtooth', fc = '0.8')
    leafNode = dict(boxstyle = 'round4', fc = '0.8')
    numLeafs = getNumLeaf(dTree)
    depth = getTreeDepth(dTree)
    firstStr = next(iter(dTree))
    startPosition = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(startPosition, arrowPosition, text)
    plotNode(firstStr, startPosition, arrowPosition, dNode)
    secondDict = dTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():                               
        if type(secondDict[key]).__name__=='dict':                                           
            plotTree(secondDict[key],startPosition,str(key))                                        
        else:                                                                               
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), startPosition, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), startPosition, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

'''
Function : createPlot(dTree)
    Description : to use plt plot dTree figure
    Args : dTree
    Rets : None
'''
def createPlot(dTree):
    fig = plt.figure(1, facecolor='white')                                                   
    fig.clf()                                                                                
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)                                
    plotTree.totalW = float(getNumLeaf(dTree))                                          
    plotTree.totalD = float(getTreeDepth(dTree))                                        
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(dTree, (0.5,1.0), '')
    plt.show()       


    
    
if __name__ == '__main__':
    featureMatrix, labels = createDataSet()
    #print('maxEntropy feature is : ' + str(maxEntropy(featureMatrix)))
    featureLabels = []
    dTree = createTree(featureMatrix, labels, featureLabels)
    print(dTree)
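    #added usage sketch: featureLabels is ['house', 'job'] for this data,
    #so [0, 1] means house = 0, job = 1 and should print 'yes'
    print(classify(dTree, featureLabels, [0, 1]))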
    createPlot(dTree)

The result is as follows: the printed tree is {'house': {0: {'job': {0: 'no', 1: 'yes'}}, 1: 'yes'}}, i.e. the tree splits on house first and then on job, and createPlot draws the corresponding diagram.

Reference blog: https://cuijiahua.com/blog/2017/11/ml_2_decision_tree_1.html
