决策树ID3算法实现

如有问题,欢迎指出。

 

 

决策树模型

决策树(decision tree)是一种基本的分类与回归方法。我这里主要只讨论用于分类的决策树。在分类问题里面,决策树是根据样本的特征进行分类,模型可以认为是if-then规则的集合,举个例子,就好像是在给定的区间里面猜数字,每猜一次就会告诉你大了还是小了,然后根据这个进一步判断,不停的递归,最终找到那个数字。决策树学习的步骤通常就三个,特征选择、决策树生产和剪枝(防止过拟合)。而决策树生成也有三个方法,ID3、C4.5和CART,分别对应着不同的特征选择方式。

定义:分类决策树模型是一种描述对实例进行分类的树形结构。决策树由结点和有向边组成。

决策树有两种节点,叶节点和非叶节点(内部节点)。非叶节点可以理解为一个特征就好比是性别是男是女,或者年龄大于30岁还是小于30岁等等,而叶节点就表示一个类别。用决策数分类其实就是从根节点开始,根据输入的x的特征进行划分,不断的递归往叶子节点走,直到走到某个叶子节点,那么就代表着x是属于那一类的。而且,从根走到叶子节点的路径有且唯一。因为决策树是可以看成一个if-then规则的集合,而这个规则是互斥且完备的(因为我们制定规则的时候不会模棱两可)。这就表明了每一个实例都被一条路径或者是一条规则覆盖了,并且只有一条。

 

ID3算法

ID3算法的核心就是在决策树各个节点上应用信息增益准则进行选取特征,然后递归地构建决策树。

首先,先解释一下什么是信息增益准则。

定义(信息增益):特征A对训练数据集D的信息增益g(D,A),定义为集合D的经验熵H(D)与特征A给定条件下D的条件经验熵H(D|A)之差,即

                                                                                                                                                                        g(D,A)=H(D) - H(D|A)

一般地,熵H(Y)与条件熵H(Y|X)之差称为互信息。决策树学习中的信息增益等价于训练数据集中类与特征的互信息。这个我的理解是,D所包含的信息量也就是熵是H(D),而当有A特征之后D的信息量变为了H(D|A),这两个相减就可以看成A这个特征所带来的信息量(假设你们对熵有一定的了解,熵表示信息量的多少....),也就是信息增加了多少。信息增益主要依赖于特征,不同的特征往往具有不同的信息增益,所以信息增益大的特征具有更强的分类能力。

 

ID3算法:

输入:训练数据集D,特征集A,阈值\varepsilon ;

输出:决策树T。

  1. 若D中所有实例属于同一类C_k,则T为单结点树,并将类C_k作为该结点的类标记,返回T;
  2. A=\varnothing,则T为单结点树,并将D中实例数最大的类C_k作为该结点的类标记,返回T;
  3. 否则,计算A中各特征对D的信息增益,选择信息增益最大的特征A_g
  4. 如果A_g的信息增益小于阈值\varepsilon ,则置T为单结点树,并将D中实例数最大的类C_k作为该结点的类标记,返回T;
  5. 否则,对A_g的每一可能值a_i,依A_g=a_i将D分割为若干非空子集D_i,将D_i中实例数最大的类作为标记,构建子结点,由结点及其子结点构成树T,返回T;
  6. 对第i个子结点,以D_i为训练集,以A-{A_g}为特征集,递归地调用步骤(1)~(5),得到子树T_i,返回T_i

以下是根据机器学习实战中代码实现ID3(如果要实现C4.5就是改写一下里面选特征的函数)

import numpy as np
import matplotlib.pyplot as plt
import operator


def calEnt(datasets):
    numEntries = len(datasets)
    labelCounts = {}
    for featVec in datasets:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    Ent = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        Ent -= prob * np.log2(prob)
    return Ent

def splitDataset(dataset, axis, value):
    retDataset = []
    for featVec in dataset:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataset.append(reducedFeatVec)
    return retDataset


def chooseBestFeatureToSplit(dataset):
    # 计算列数
    numFeatures = len(dataset[0]) - 1
    # 计算H(D)
    baseEntropy = calEnt(dataset)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataset]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataset(dataset, i, value)
            prob = len(subDataSet) / float(len(dataset))
            newEntropy += prob * calEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataset, labels):
    classList = [example[-1] for example in dataset]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataset)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataset]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataset(dataset, bestFeat, value), subLabels)
    return myTree

"""
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords="axes fraction", xytext=centerPt, textcoords="axes fraction",
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def createPlot():
    fig = plt.figure(1, facecolor="white")
    fig.clf()
    createPlot().ax1 = plt.subplot(111, frameon=False)
    plotNode('决策结点', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('叶结点', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
"""


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    maxDepth = 0
    fitstStr = myTree.keys()[0]
    secondDict = myTree[fitstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


if __name__ == '__main__':

    datasets = [
        ["青年", "否", "否", "一般", "否"],
        ["青年", "否", "否", "好", "否"],
        ["青年", "是", "否", "好", "是"],
        ["青年", "是", "是", "一般", "是"],
        ["青年", "否", "否", "一般", "否"],
        ["中年", "否", "否", "一般", "否"],
        ["中年", "否", "否", "好", "否"],
        ["中年", "是", "是", "好", "是"],
        ["中年", "否", "是", "非常好", "是"],
        ["中年", "否", "是", "非常好", "是"],
        ["老年", "否", "是", "非常好", "是"],
        ["老年", "否", "是", "好", "是"],
        ["老年", "是", "否", "好", "是"],
        ["老年", "是", "否", "非常好", "是"],
        ["老年", "否", "否", "一般", "否"],
    ]
    labels = ["年龄", "有工作", "有自己的房子", "信贷情况"]
    """
    datasets = pd.DataFrame(datasets, columns=["年龄", "有工作", "有自己的房子", "信贷情况", "类别"])

    for i in range(datasets.shape[1]):
        listUniq = datasets.iloc[:, i].unique()
        print(listUniq)
        for j in range(len(listUniq)):
            datasets.iloc[:, i] = datasets.iloc[:, i].apply(lambda x: j if x == listUniq[j] else x)
    print("dataset:\n", datasets)
    datasets = np.array(datasets)
    print("dataset:\n", datasets)
    """
    # print(chooseBestFeatureToSplit(datasets))
    # decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    # leafNode = dict(boxstyle="round4", fc="0.8")
    # arrow_args = dict(arrowstyle="<-")
    myTree = createTree(datasets, labels)
    print("myTree:\n", myTree)

 

 

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值