Decision Trees (DecisionTree): Plain-Language Principles and a Simple Application

DecisionTree.py

from math import log
#Decision trees use Shannon entropy from information theory to measure how mixed a set of class labels is:
#the higher the entropy of a set, the more disordered it is.
#To choose a splitting feature, compute the entropy of the current set, then the weighted entropy of the
#subsets produced by splitting on each candidate feature; the feature whose split reduces entropy the most
#(i.e., has the largest information gain) is chosen.
#Keep choosing the best split until every node is pure or all features have been used, at which point the tree is complete.
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2)
    return shannonEnt
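# Worked example (my own check, not from the original post): the sample dataSet
# defined at the bottom of this file has 2 'yes' and 3 'no' labels, so
# calcShannonEnt returns -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.9710.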

def splitDataSet(dataSet, axis, value):
    retDataset = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataset.append(reducedFeatVec)
    return retDataset
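# Worked example (my own check): on the sample dataSet below,
# splitDataSet(dataSet, 0, 1) keeps the rows whose feature 0 equals 1 and
# strips that column, giving [[1, 'yes'], [1, 'yes'], [0, 'no']].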

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
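# Worked example (my own check): on the sample dataSet below, splitting on
# feature 0 gives gain 0.9710 - (3/5*0.9183 + 2/5*0) ≈ 0.4200, while feature 1
# gives 0.9710 - (4/5*1.0 + 1/5*0) ≈ 0.1710, so feature 0 is chosen.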

def majorityCnt(classList):
    # return the class label that occurs most often in classList
    classCount = {}
    for vote in classList:
        if vote not in classCount: classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=lambda d:d[1], reverse = True)
    return sortedClassCount[0][0]
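# Example (my own check): majorityCnt(['yes', 'no', 'yes']) returns 'yes'.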

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    #print(len(classList))
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    #print(bestFeat)
    #print(bestFeatLabel)
    myTree = {bestFeatLabel:{}}#use a dict as the new tree node
    tmplabels = labels[:]#copy, so the caller's labels list is not mutated
    del(tmplabels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    #print(uniqueVals)
    for value in uniqueVals:
        subLabels = tmplabels[:]
        #attach a subtree (or a leaf label) to the current dict node for each value of the best feature
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None  # stays None if testVec has a feature value never seen in training
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

def creatDataSet():
    dataSet = [
        [1,1,'yes'],
        [1,1,'yes'],
        [1,0,'no'],
        [0,1,'no'],
        [0,1,'no']
    ]
    labels = ['no surfacing','flippers']
    return dataSet, labels

myData, labels = creatDataSet()
print("1",labels)
myTree = createTree(myData, labels)
print("2",labels)
print(myTree) #result = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
test1 = classify(myTree,labels,[1,0])
test2 = classify(myTree,labels,[1,1])
print("test1: ",test1) # test1:  no
print("test2: ",test2) # test2:  yes

A simple application: classifying the contact-lenses dataset

import ch2.DecisionTree as dTree
with open("lenses.txt") as fr:
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = dTree.createTree(lenses, lensesLabels)
print(lensesTree)

Once the decision tree is built, it can be serialized to disk and loaded back into memory whenever it is needed, so there is no need to retrain; this saves time.

#Any Python object can be serialized with pickle
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')  # pickle writes bytes, so open in binary mode
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')  # read back in binary mode
    return pickle.load(fr)
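A usage sketch (my own example; the filename is arbitrary):

storeTree(myTree, 'classifierStorage.txt')
print(grabTree('classifierStorage.txt'))  # prints the same tree dict that was stored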

test.py — assorted Python exercises and notes from the learning process.

#In Python, a function receives a reference to a list argument; mutating the list inside the function
#affects that list object for its entire lifetime outside the function as well (a short demo follows)
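#A minimal demo of that point (my own example, kept disabled like the other snippets below):
'''
def addItem(lst):
    lst.append(99)   # mutates the caller's list in place
nums = [1, 2]
addItem(nums)
print(nums)          # [1, 2, 99] -- the caller's list changed
#this is why createTree copies labels with labels[:] before deleting an entry
'''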

import matplotlib.pyplot as plt
'''
#Decision-tree node annotation plotting
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,
    xycoords='axes fraction',
    xytext=centerPt, textcoords='axes fraction',
    va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

createPlot()
'''

'''
#Test the difference between extend and append
a = [1,2,3]
b = [3,4,5]
a.append(b)
print(a)
#append yields [1, 2, 3, [3, 4, 5]]: the whole list b is added as a single element, leaving 4 elements
c = [1,2,3]
d = [3,4,5]
c.extend(d)
print(c)
#extend yields [1, 2, 3, 3, 4, 5]: d's elements are added one by one, leaving 6 elements
'''

'''
#Test dict_keys
#In Python 3, dict's keys(), values(), and items() return view objects; convert with list() to index into them and get each key
d= {'a':{'d':2},'b':1,'c':{}}
print(list(d.keys()))
print(list(d.keys())[0])
'''