决策树算法代码解释

引用数据集获取:

西瓜数据集2.0获取

如何构造一个决策树?
我们使用 createBranch() 方法,如下所示:

def createBranch():
'''
此处运用了递归的思想。感兴趣可以搜索 递归 recursion,甚至是 dynamic programming。
'''
    检测数据集中的所有数据的分类标签是否相同:
        If so return 类标签
        Else:
            寻找划分数据集的最好特征(划分之后信息熵最小,也就是信息增益最大的特征)
            划分数据集
            创建分支节点
                for 每个划分的子集
                    调用函数 createBranch (创建分支的函数)并增加返回结果到分支节点中
            return 分支节点

 参看西瓜书76页的西瓜数据集2.0

def createDataList():
    """Return the watermelon dataset 2.0 and its feature labels.

    Returns:
        tuple: (samples, featureLabels) where samples is a list of 17 rows,
        each holding six categorical feature values followed by the class
        label ('好瓜' / '坏瓜') in the last position, and featureLabels names
        the six feature columns in order.
    """
    # Samples 1-8 are the positive class ('好瓜'), 9-17 the negative ('坏瓜').
    samples = [
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],  # 1
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],  # 2
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],  # 3
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],  # 4
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],  # 5
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],  # 6
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],  # 7
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],  # 8
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],  # 9
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],  # 10
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],  # 11
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],  # 12
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],  # 13
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],  # 14
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],  # 15
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],  # 16
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],  # 17
    ]
    featureLabels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
    return samples, featureLabels

 

def calcShannonEnt(dataList):
    """Compute the Shannon entropy (base 2) of the class labels in dataList.

    Each row's last element is taken as the class label.

    Args:
        dataList: list of sample rows; row[-1] is the class label.

    Returns:
        float: entropy in bits; 0.0 for an empty dataset.
    """
    # Local imports keep this snippet self-contained; the article has no
    # top-of-file import block. The original called np.math.log, but `np`
    # is never imported anywhere in the file and numpy.math was removed in
    # NumPy 1.25 — math.log2 is the correct stdlib call.
    from collections import Counter
    from math import log2

    dataCount = len(dataList)
    if dataCount == 0:
        # Guard: an empty split would otherwise raise ZeroDivisionError.
        return 0.0

    labelCounts = Counter(featVec[-1] for featVec in dataList)

    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / dataCount
        # H = -sum(p * log2(p)); prob > 0 here, so log2 is always defined.
        shannonEnt -= prob * log2(prob)
    return shannonEnt

按照给定特征划分数据集

def splitDataList(dataList, index, value):
    """Select rows whose feature at `index` equals `value`, dropping that column.

    Args:
        dataList: list of sample rows.
        index: position of the feature to filter on.
        value: required value of that feature.

    Returns:
        A new list of new rows with the matched column removed; the input
        rows are never modified.
    """
    # Slicing on both sides of `index` builds a fresh row without that column.
    return [row[:index] + row[index + 1:]
            for row in dataList
            if row[index] == value]

选择最好的数据集划分方式

def chooseBestFeatureToSplit(dataList):
    """Pick the feature index whose split yields the highest information gain.

    Information gain = entropy(dataList) - weighted entropy of the subsets
    produced by splitting on the feature (ID3 criterion).

    Args:
        dataList: non-empty list of sample rows; row[-1] is the class label,
            the remaining columns are candidate features.

    Returns:
        int: index of the best feature, or -1 if no feature gives a
        positive information gain.
    """
    numFeatures = len(dataList[0]) - 1  # last column is the class label
    baseEnt = calcShannonEnt(dataList)
    bestInfoGain, bestFeature = 0.0, -1

    for i in range(numFeatures):
        uniqueVals = {example[i] for example in dataList}

        # Weighted average entropy after splitting on feature i.
        newEnt = 0.0
        for value in uniqueVals:
            subDataList = splitDataList(dataList, i, value)
            prob = len(subDataList) / float(len(dataList))
            newEnt += prob * calcShannonEnt(subDataList)

        # Removed the leftover debug print here: its 'bestFeature=' label
        # misleadingly showed the current index i, not the best feature.
        infoGain = baseEnt - newEnt
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i

    return bestFeature

当划分到最后特征不一致时,少数服从多数

def majorityCnt(classList):
    """Return the most frequent class label in classList (majority vote).

    The original called dict.iteritems(), which only exists in Python 2 and
    raises AttributeError on Python 3 (which this file targets — it uses
    print() calls); it also relied on the never-imported `operator` module.
    collections.Counter replaces both.

    Args:
        classList: non-empty list of class labels.

    Returns:
        The label with the highest count; ties break in favor of the label
        encountered first, matching the original stable sort.
    """
    from collections import Counter  # local import: file has no import block
    return Counter(classList).most_common(1)[0][0]

创建树

def createTree(dataList, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataList: non-empty list of sample rows; row[-1] is the class label.
        labels: feature names aligned with the feature columns of dataList.
            Unlike the original (which did `del labels[...]` in place), the
            caller's list is NOT mutated.

    Returns:
        A class label (str) at a leaf, or a nested dict of the form
        {featureLabel: {featureValue: subtree, ...}}.
    """
    classList = [example[-1] for example in dataList]

    # Base case 1: every sample belongs to the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # Base case 2: only the class column remains -> majority vote.
    if len(dataList[0]) == 1:
        return majorityCnt(classList)

    bestFeatIdx = chooseBestFeatureToSplit(dataList)
    bestFeatLabel = labels[bestFeatIdx]

    myTree = {bestFeatLabel: {}}

    # Build the reduced label list on a copy so the caller's `labels`
    # survives intact; it stays aligned with the columns splitDataList keeps.
    subLabels = labels[:bestFeatIdx] + labels[bestFeatIdx + 1:]

    featValues = {example[bestFeatIdx] for example in dataList}
    for value in featValues:
        myTree[bestFeatLabel][value] = createTree(
            splitDataList(dataList, bestFeatIdx, value), subLabels)

    return myTree

调用

if __name__ == '__main__':
    # Build the dataset, grow the ID3 tree, and show the nested-dict result.
    watermelonData, featureLabels = createDataList()
    decisionTree = createTree(watermelonData, featureLabels)
    print(decisionTree)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值