Python-常用机器学习算法-决策树

def chooseBestFeatureToSplit(dataSet):
    """Pick the feature whose split yields the largest information gain.

    Each row of dataSet is a list of feature values with the class label
    in the last column.  Returns the column index of the best feature,
    or -1 when no split reduces entropy at all.
    """
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain, bestFeature = 0.0, -1
    totalRows = float(len(dataSet))
    for featIndex in range(numFeatures):
        # Weighted average entropy of the partitions this feature induces.
        distinctValues = {row[featIndex] for row in dataSet}
        splitEntropy = 0.0
        for value in distinctValues:
            subset = splitDataSet(dataSet, featIndex, value)
            weight = len(subset) / totalRows
            splitEntropy += weight * calcShannonEnt(subset)
        gain = baseEntropy - splitEntropy  # reduction in entropy
        if gain > bestInfoGain:
            bestInfoGain, bestFeature = gain, featIndex
    return bestFeature
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    dataSet: rows of feature values with the class label in the last column.
    labels:  feature names parallel to the feature columns.  The caller's
             list is NOT modified (the original code did
             ``del(labels[bestFeat])``, which corrupted the caller's label
             list and broke later classify() calls that reuse it).

    Returns a class label (leaf) or a nested dict of the form
    {featureLabel: {featureValue: subtree_or_leaf, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    # Stop splitting when every remaining row has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop splitting when only the label column is left: majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Build the child label list without the chosen feature instead of
    # deleting from `labels`, so the argument stays intact for the caller.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

# myDat,labels = createDataSet()

# mytree = createTree(myDat,labels)

# print(mytree)

def classify(inputTree, featLabels, testVec):
    """Classify one sample by walking a tree built by createTree.

    inputTree:  nested dict {featureLabel: {featureValue: subtree_or_leaf}}.
    featLabels: feature names, ordered like the columns of testVec.
    testVec:    feature values of the sample to classify.

    Raises KeyError when testVec contains a feature value the tree never
    saw at training time.  Returns the predicted class label.
    """
    # inputTree.keys()[0] is Python-2-only (dict_keys is not subscriptable
    # on Python 3); next(iter(...)) fetches the single root key on both.
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        # Internal node: recurse into the subtree matching this value.
        return classify(valueOfFeat, featLabels, testVec)
    return valueOfFeat  # leaf: class label

# myDat,labels = createDataSet()

# mytree = createTree(myDat,labels)

# print(classify(mytree, labels, [1, 0]))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值