决策树(Decision Trees)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/universe_ant/article/details/52607126

你是否玩过二十个问题的游戏,游戏的规则很简单:参与游戏的一方在脑海里想某个事物,其他参与者向他提问题,只允许提20个问题,问题的答案也只能用对或错回答。问问题的人通过推断分解,逐步缩小带猜测事物的范围。决策树的工作原理与20个问题类似,用户输入一系列数据,然后给出游戏的答案。

下图所示的流程图就是一个决策树,正方形代表判断模块(decision block),椭圆形代表终止模块(terminating block),表示已经得出结论,可以终止运行。从判断模块引出的左右箭头称作分支(branch),它可以到达另一个判断模块或者终止模块。下图其实构造了一个假想的邮件分类系统,首先检测发送的邮件域名地址,如果地址为myEmployer.com,则将其放在分类“无聊时阅读的邮件”中;如果邮件不是来自这个域名,则检查邮件内容是否包括单词"hockey"(曲棍球),如果包含则将邮件归类到“需要及时处理的朋友邮件”,如果不包含则将邮件归类到“无需阅读的垃圾邮件”。

决策树举例

在构造决策树时,我们需要解决的第一个问题就是,当前数据集上哪个特征在划分数据分类时起决定性作用。为了找到这个决定性的特征,划分出最好的结果,我们必须对每个特征进行评估。完成测试之后,原始数据集就被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支上。如果某个分支下的数据属于同一类型,则当前分支下的数据已经正确分为一类,无需进一步对数据集进行分割。如果数据子集内的数据不属于同一个类型,则需要继续划分数据子集。如何划分数据子集的算法和划分原始数据集的方法相同,知道所有具有相同类型的数据均在一个数据子集内。

决策树的一般流程:

(1)收集数据:可以使用任何方法。

(2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。

(3)分析数据:可以使用任何方法,构造树完成后,我们应该检查是否符合预期。

(4)训练算法:构造树的数据结构。

(5)测试算法:使用经验树计算错误率。

(6)使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。

构造决策树的伪代码:

检测数据集中的每个子项是否属于同一分类:

if so: return 类标签;

else:

寻找划分数据集的最好特征

划分数据集

创建分支节点

for 每个划分的子集

递归调用函数并增加返回结果到分支节点中

return 分支节点

示例:

下表的数据包含5个海洋生物,特征包括:不浮出水面是否可以生存,以及是否有脚蹼。我们可以将这些动物分成两类:鱼类和非鱼类。现在我们想要决定依据第一个特征还是第二个特征划分数据。在回答这个问题之前,我们必须采用量化的方法判断如何划分数据。这里就要引入信息增益的概念。

海洋生物数据示例表格

信息增益

划分数据集的大原则是:将无序的数据变得更加有序。我们可以使用多种方法划分数据集,但是每种方法都有自己的优缺点。组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支科学。我们可以在划分数据之前使用信息论量化度量信息的内容。

在划分数据集之前之后信息发生的变化称为信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。

在可以评测哪种数据划分方式是最好的数据划分之前,我们必须学习如何计算信息增益。集合信息的度量方式称为香农熵或者简称为熵,这个名字来源于信息论之父克劳德·香农。

熵定义为信息的期望值,在明晰这个概念之前,我们必须知道信息的定义。如果待分类的事物可能划分在多个分类中,则符号xi信息定义为-l(xi)

信息

其中p(xi)是选择该分类的概率。

为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,通过下面的公式得到:

信息期望值

其中n是分类的数目。

以下代码中方法chooseBestFeatureToSplit()可以计算最佳信息增益属性:

# function to calculate the Shannon entropy of dataset
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    # create dictionary of all possible classes
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    # logarithm base 2
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

# dataset splitting on a given feature
def splitDataSet(dataSet, axis, value):
    """
    params:
        dataSet: the dataset we will split
        axis: the feature we will split on
        value: the value of the feature to return
    """
    # create separate list
    retDataSet = []
    # cut out the feature split on
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

# choosing the best feature to split on
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    # calculate the Shannon entropy of the whole dataset
    # before any splitting has occured
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # create unique list of class labels
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        
        newEntropy = 0.0
        # calculate entropy for each split
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        # find the best information gain
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

按照上面提供的决策树伪代码,我们可以写出如下代码:

def majorityCnt(classList):
    """
    If our dataset has run out of attributes but the class labels
    are not all the same, we must decide what to call that leaf node.
    In this situation, we will take a majority vote.
    """
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# tree building code
def createTree(dataSet, labels):
    """
    params:
        dataSet: the dataSet we will use to create decision-tree
        labels: the list of labels contains a label for each of
                the features in the dataset
    """
    classList = [example[-1] for example in dataSet]
    # stop when all classes are equal
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # when no more features, return majority
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    # get list of unique values
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree( \
                        splitDataSet(dataSet, bestFeat, \
                                value), subLabels)
    return myTree

最后可以根据决策树进行分类,具体分类代码如下:

# classification function for an existing decision tree
def classify(inputTree, featLabels, testVec):
	firstStr = inputTree.keys()[0]
	secondDict = inputTree[firstStr]
	# translate label string to index
	featIndex = featLabels.index(firstStr)
	for key in secondDict.keys():
		if testVec[featIndex] == key:
			if type(secondDict[key]).__name__ == 'dict':
				classLabel = classify(secondDict[key], featLabels, testVec)
			else:
				classLabel = secondDict[key]
	return classLabel

当然,我们知道构建决策树是一件很耗时的任务,所以我们可以将已经构建的决策树进行序列化,这样很方便下次使用,下面代码使用python的模块pickle进行序列化和反序列化操作:

# methods for persisting the decision tree with pickle
def storeTree(inputTree, filename):
	import pickle
	fw = open(filename, 'w')
	pickle.dump(inputTree, fw)
	fw.close()

def grabTree(filename):
	import pickle
	fr = open(filename)
	return pickle.load(fr)


展开阅读全文

没有更多推荐了,返回首页