机器学习实战:决策树(Decision Trees)


from numpy import *
from math import log
import operator

def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the class labels in dataSet.

    Each row's last element is taken as its class label. An empty
    dataSet yields 0.0 because no label contributes to the sum.
    """
    total = len(dataSet)

    # Tally how many rows carry each class label.
    tally = {}
    for row in dataSet:
        cls = row[-1]
        if cls in tally:
            tally[cls] += 1
        else:
            tally[cls] = 1

    # H = -sum(p * log2(p)) over the observed labels.
    entropy = 0.0
    for count in tally.values():
        prob = float(count) / total
        entropy -= prob * log(prob, 2)
    return entropy
       
def createDataSet():
    """Return a small sample data set and the names of its feature columns.

    Every row holds two 0/1 feature values followed by a 'yes'/'no'
    class label. Fresh lists are built on every call.
    """
    featureNames = ['no surfacing', 'flippers']
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return samples, featureNames
    
def splitDataSet(dataSet, axis, value):
    """Return the rows whose column `axis` equals `value`, minus that column.

    The input rows are not modified; each returned row is a new list
    made of the elements before and after position `axis`.
    """
    return [
        row[:axis] + row[axis + 1:]
        for row in dataSet
        if row[axis] == value
    ]
    
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature column with the largest information gain.

    The last column of every row is treated as the class label, so only
    the columns before it are candidates. Returns -1 when there are no
    feature columns. Ties keep the lowest index because a later feature
    must strictly beat the current best.
    """
    baseEntropy = calcShannonEnt(dataSet)
    featureCount = len(dataSet[0]) - 1
    totalRows = float(len(dataSet))

    winner, topGain = -1, -1
    for col in range(featureCount):
        # Weighted entropy of the partitions produced by splitting on `col`.
        conditional = 0.0
        for val in set([row[col] for row in dataSet]):
            subset = splitDataSet(dataSet, col, val)
            conditional += len(subset) / totalRows * calcShannonEnt(subset)

        gain = baseEntropy - conditional
        if gain > topGain:
            winner, topGain = col, gain
    return winner
    
def majorityCnt(classList):
    """Return the most frequent item in classList.

    When several items share the highest count, the one that appears
    first in classList wins (on Python versions whose dicts preserve
    insertion order).

    Raises:
        ValueError: if classList is empty, since there is no majority.
    """
    classCount = {}
    for item in classList:
        # Default of 0 is required: get() without it returns None for a
        # label seen for the first time, and None + 1 is a TypeError.
        classCount[item] = classCount.get(item, 0) + 1
    if not classCount:
        raise ValueError("classList must not be empty")
    return max(classCount, key=classCount.get)
    
def createTree(dataSet,labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        dataSet: non-empty list of rows; the last element of each row is
            its class label, the preceding elements are feature values.
        labels: feature names, one per feature column. This list is not
            modified.

    Returns:
        A class label when every row has the same class, or when no
        feature columns remain (the majority class is used). Otherwise a
        dict of the form {featureName: {featureValue: subtree, ...}}.
    """
    classList=[item[-1] for item in dataSet]
    # Stop: all rows agree on the class.
    if len(set(classList))==1:
        return classList[0]
    # Stop: only the class column is left, so fall back to a vote.
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeature=chooseBestFeatureToSplit(dataSet)
    bestLabel=labels[bestFeature]
    # Build a reduced copy instead of deleting from `labels`, so the
    # caller's list keeps every feature name and stays usable afterwards.
    subLabels=labels[:bestFeature]+labels[bestFeature+1:]
    uniqueVals=set([example[bestFeature] for example in dataSet])
    trees={bestLabel:{}}
    for value in uniqueVals:
        subDataSet=splitDataSet(dataSet,bestFeature, value)
        trees[bestLabel][value]=createTree(subDataSet,subLabels)
    return trees

def classify(inputTree,featLabels,testVec):
    """Classify testVec by walking a nested-dict decision tree.

    Args:
        inputTree: dict of the form {featureName: {featureValue: subtree}}
            where a subtree is either another such dict or a class label.
        featLabels: feature names; the position of a name gives the index
            of that feature's value in testVec.
        testVec: feature values of the sample to classify.

    Returns:
        The class label stored at the leaf reached by testVec.

    Raises:
        ValueError: if inputTree is empty or its root feature name is not
            in featLabels.
        KeyError: if the tree has no branch for testVec's feature value.
    """
    if not inputTree:
        raise ValueError("inputTree is empty")
    # next(iter(...)) works on Python 2 and 3; keys()[0] only on Python 2.
    firstStr=next(iter(inputTree))
    secTree=inputTree[firstStr]
    try:
        featIndex=featLabels.index(firstStr)
    except ValueError:
        # Fail here with a clear message rather than continuing with an
        # unbound index.
        raise ValueError("List does not contain value: %r" % (firstStr,))
    featValue=testVec[featIndex]
    if featValue not in secTree:
        raise KeyError("no branch for %r = %r" % (firstStr, featValue))
    branch=secTree[featValue]
    if isinstance(branch, dict):
        return classify(branch,featLabels,testVec)
    return branch

def storeTree(inputTree,filename):
    """Serialize inputTree to the file `filename` using pickle.

    Any existing file at that path is overwritten.
    """
    import pickle
    # Pickle produces bytes, so the file must be opened in binary mode
    # (text mode fails on Python 3). The with-block closes the file even
    # if dump() raises.
    with open(filename,'wb') as fw:
        pickle.dump(inputTree,fw)
    
def grabTree(filename):
    """Load and return the object pickled in the file `filename`.

    NOTE(security): unpickling can execute arbitrary code; only call this
    on files from a trusted source.
    """
    import pickle
    # Binary mode is required for pickle data; the with-block makes sure
    # the file handle is closed once loading finishes or fails.
    with open(filename,'rb') as fr:
        return pickle.load(fr)

# Demo: build a tree from the sample data set and classify one sample.
dataSet,label=createDataSet()
trees=createTree(dataSet,label)
# print(x) with a single argument behaves the same on Python 2 and 3,
# whereas the statement form `print x` is a SyntaxError on Python 3.
print(trees)
# Fetch a fresh label list so classify() sees the full feature order
# regardless of what createTree did with the list it was given.
dataSet,label=createDataSet()
r=classify(trees,label,[0,1])
print(r)

'''
fr=open('lenses.txt')
lenses=[inst.strip().split('\t') for inst in fr.readlines()]
lenseLabel=['age', 'prescript', 'astigmatic', 'tearRate']
trees=createTree(lenses,lenseLabel)
lenseLabel=['age', 'prescript', 'astigmatic', 'tearRate']
result=classify(trees,lenseLabel,['pre','myope','no','normal'])
print result
'''


  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值