Predicting Contact Lens Type with a Decision Tree (Python Machine Learning Notes)

Using a decision tree to predict the type of contact lenses a patient should be fitted with.

tree.py (code for building the tree)
# -*- coding: UTF-8 -*-
from math import log
import operator

# Compute the Shannon entropy of a data set; higher entropy means the data are more mixed
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
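
# Worked example (my own arithmetic, using the lenses data below): 24 records
# with class counts {no lenses: 15, soft: 5, hard: 4}, so the entropy is
#   H = -(15/24)*log2(15/24) - (5/24)*log2(5/24) - (4/24)*log2(4/24) ≈ 1.326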

# Split the data set on the given feature index (axis) and value,
# returning the matching rows with that feature column removed
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
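
# e.g. splitDataSet([[1, 'a', 'yes'], [0, 'b', 'no']], 0, 1)
# returns [['a', 'yes']]: the matching row with column 0 removed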

# Choose the best feature to split on (the one with maximum information gain)
def chooseBestFeatureTosplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
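
# Worked example (my own arithmetic) on the lenses data: splitting on tearRate gives
#   reduced: 12 rows, all "no lenses"              -> entropy 0
#   normal:  12 rows (5 soft, 4 hard, 3 no lenses) -> entropy ≈ 1.555
# so newEntropy = 0.5*0 + 0.5*1.555 ≈ 0.777 and the information gain is
# about 1.326 - 0.777 ≈ 0.549, the best of the four features, which is
# why tearRate ends up at the root of the tree.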

# Return the majority class label among the given class list
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

# Recursively build the decision tree as nested dicts
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureTosplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])  # note: this mutates the caller's labels list
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy the labels so recursion doesn't mess up the existing list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
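
Before turning to the real data, it helps to sanity-check tree.py on a tiny hand-made data set. The toy rows and labels below are made up for illustration and are not part of the post's data:

# A quick sanity check of tree.py on a made-up toy data set
import tree

toyData = [['sunny', 'yes', 'play'],
           ['sunny', 'no', 'play'],
           ['rainy', 'yes', 'no play'],
           ['rainy', 'no', 'no play']]
toyLabels = ['weather', 'free time']

print(tree.calcShannonEnt(toyData))            # 1.0: two classes, evenly split
print(tree.splitDataSet(toyData, 0, 'sunny'))  # [['yes', 'play'], ['no', 'play']]
print(tree.createTree(toyData, toyLabels[:]))  # {'weather': {'sunny': 'play', 'rainy': 'no play'}}

Passing toyLabels[:] keeps the original label list intact, since createTree() deletes from the list it is given.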

treePlotter.py (code for plotting the tree)

# -*- coding: UTF-8 -*-
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Count the leaf nodes of the tree
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):  # an internal node is a dict; anything else is a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Get the depth of the tree (number of decision levels)
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):  # internal node: recurse one level deeper
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
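
# Example (made-up tree): for {'A': {0: 'x', 1: {'B': {0: 'y', 1: 'z'}}}}
# getNumLeafs() returns 3 ('x', 'y', 'z') and getTreeDepth() returns 2.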


# Draw a node with an arrow pointing from the parent node to it
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# Plot text (the feature value) at the midpoint of the edge between parent and child
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):  # the first key tells you what feature was split on
    numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):  # internal node: recurse into the subtree
            plotTree(secondDict[key], cntrPt, str(key))
        else:  # leaf node: plot it directly
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


# Entry point: set up the figure and kick off the recursive plotting
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
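
createPlot() can be exercised on its own with a small hand-written nested dict before the real tree is built; the two-feature tree below is just a made-up demo:

# Exercise the plotting code with a small hand-built tree (made-up demo)
import treePlotter

demoTree = {'no surfacing': {0: 'no',
                             1: {'flippers': {0: 'no', 1: 'yes'}}}}
treePlotter.createPlot(demoTree)  # opens a matplotlib window showing the tree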

  1. Collect data: lenses.txt
    young	myope	no	reduced	no lenses
    young	myope	no	normal	soft
    young	myope	yes	reduced	no lenses
    young	myope	yes	normal	hard
    young	hyper	no	reduced	no lenses
    young	hyper	no	normal	soft
    young	hyper	yes	reduced	no lenses
    young	hyper	yes	normal	hard
    pre	myope	no	reduced	no lenses
    pre	myope	no	normal	soft
    pre	myope	yes	reduced	no lenses
    pre	myope	yes	normal	hard
    pre	hyper	no	reduced	no lenses
    pre	hyper	no	normal	soft
    pre	hyper	yes	reduced	no lenses
    pre	hyper	yes	normal	no lenses
    presbyopic	myope	no	reduced	no lenses
    presbyopic	myope	no	normal	no lenses
    presbyopic	myope	yes	reduced	no lenses
    presbyopic	myope	yes	normal	hard
    presbyopic	hyper	no	reduced	no lenses
    presbyopic	hyper	no	normal	soft
    presbyopic	hyper	yes	reduced	no lenses
    presbyopic	hyper	yes	normal	no lenses
    
  2. Prepare data: parse the tab-separated lines.
  3. Analyze data: quickly inspect the parsed data to make sure it is correct, then use the createPlot() function to draw the final tree diagram.
  4. Train the algorithm: use the createTree() function.
  5. Test the algorithm: write a test function to verify that the decision tree correctly classifies a given instance (see the classify() sketch after this list).
  6. Use the algorithm: store the tree's data structure so it doesn't have to be rebuilt the next time it is needed (see the storeTree()/grabTree() sketch after this list).
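
Steps 5 and 6 are not covered by tree.py above, so here is a minimal sketch of both, modeled on the book's approach; classify(), storeTree(), and grabTree() are additions of mine (they could be appended to tree.py), not part of the listing above:

# Step 5: walk the nested-dict tree to classify a single instance
import pickle

def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]    # feature tested at this node
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # which column of testVec that feature is
    branch = secondDict[testVec[featIndex]]
    if isinstance(branch, dict):            # internal node: keep descending
        return classify(branch, featLabels, testVec)
    return branch                           # leaf node: the class label

# Step 6: persist the tree with pickle so it doesn't have to be rebuilt
def storeTree(inputTree, filename):
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

def grabTree(filename):
    with open(filename, 'rb') as fr:
        return pickle.load(fr)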
In a Python interactive session with tree.py and treePlotter.py importable:
>>> import tree
>>> import treePlotter
>>> fr = open("lenses.txt")
>>> lenses = [inst.strip().split("\t") for inst in fr.readlines()]
>>> lensesLabels = ["age", "prescript", "astigmatic", "tearRate"]
>>> lensesTree = tree.createTree(lenses, lensesLabels)
>>> treePlotter.createPlot(lensesTree)
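
With the classify() sketch from step 5 added to tree.py, a single instance can then be checked. Note that createTree() deletes entries from the labels list passed to it, so classification needs a fresh copy of the label names; for a row taken from lenses.txt, the fully grown tree should simply reproduce that row's label:
>>> labels = ["age", "prescript", "astigmatic", "tearRate"]
>>> tree.classify(lensesTree, labels, ["young", "myope", "no", "normal"])
'soft'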

The plotted decision tree:
[figure omitted: the lenses decision tree rendered by createPlot()]

Reference: Machine Learning in Action (Peter Harrington)




