代码注释:机器学习实战第3章 决策树

写在开头的话:在学习《机器学习实战》的过程中发现书中很多代码并没有注释,这对新入门的同学是一个挑战,特此贴出我对代码做出的注释,仅供参考,欢迎指正。

1、trees.py

#coding:gbk
from math import log
import operator


#作用:建立数据集
#输出:数据集,标签名称
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


#作用:计算香农熵
#输入:数据集列表,最后一列为类
#输出:数据集香农熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)#数据集中实例总数
    labelCounts = {}#创建字典,表示类标签出现次数
    for featVec in dataSet:
        currentLabel = featVec[-1]#最后一列,即类标签
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0#香农熵
    for key in labelCounts:#对每个类标签来说
        prob = float(labelCounts[key]) / numEntries#标签出现概率
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


#作用:按照给定特征划分数据集,去除axis对应列特征值等于value的值
#输入:待划分的数据集,划分数据集的特征即列数,需要返回的特征的值
#输入:划分后的数据集
def splitDataSet(dataSet, axis, value):
    retDataSet = []#返回列表
    for featVec in dataSet:#对数据集中每一行
        if featVec[axis] == value:#如果相等
            reducedFeatVec = featVec[:axis]#该行和下一行的作用是得到去除featVec[axis]的列表
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)#将去除featVec[axis]的列表添加到返回列表中
    return retDataSet


#作用:得到最好的数据集划分方式
#输入:数据集列表,最后一列为类
#输出:最好的数据集划分方式对应的特征值
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1#dataSet特征数,-1表示最后一列为类别标签
    baseEntropy = calcShannonEnt(dataSet)#dataSet的香农熵
    bestInfoGain = 0.0;#最大信息熵
    bestFeature = -1;#最佳特征值
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]#列表推导式,找到第i个特征对应的属性值,注意是列表,会有多个相同的属性值
        uniqueVals = set(featList)#转换为集合,消除相同的属性值,集合里只能存在不相同的属性值
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))#注意float,不能两个int值相除,只能得int值
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):#如果新特征值拥有更大的信息熵
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


#作用:返回出现最多的分类名称
#输入:分类名称的列表
#输出:出现最多的分类名称
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1#出现频率加1
    sortedClassCount = sorted(classCount.iteritems,#iteritems()表示将classCount以一个迭代器对象返回
                              key = operator.itemgetter(1), reverse = true)#operator.itemgetter(1)表示第2维数据即值,reverse = True表示从大大小排列
    return sortedClassCount


#作用:创建树
#输入:数据集,标签名称
#输出:树的字典形式
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]#列表推导式,得类别标签列表
    #类别完全相同则停止继续划分
    if classList.count(classList[0]) == len(classList):#classList.count(classList[0])表示将计算第一个类别出现的次数
        return classList[0]
    #遍历完所有特征时返回出现次数最多的类别
    #该程序用到了递归,此为递归退出条件
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)#最好的数据集划分方式对应的特征值
    bestFeatLabel = labels[bestFeat]#最好的数据集划分方式对应的特征值对应的类别名称
    myTree = {bestFeatLabel:{}}#建立字典,键为根节点类别名称
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]#列表推导式,得特征值对应值的列表
    uniqueVals = set(featValues)#建立集合
    for value in uniqueVals:
        subLabels = labels[:]#使用新变量代替原始列表
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)#创建子节点
    return myTree


#作用:使用决策树的分类函数
#输入:树的字典形式,分类标签,待分类矢量
#输出:分类标签
def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)#得firstStr在分类标签中的索引
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


#作用:储存决策树
#输入:树的字典形式,文件名字
#输出:无
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()


#作用:提取决策树
#输入:文件名字
#输出:树的字典形式
def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)

2、treePlotter.py

#coding:gbk
import matplotlib.pyplot as plt

#定义文本框和箭头格式
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")

#作用:绘制带箭头的注解
#输入:注解,注解坐标,起点坐标,注解类型
#输出:无
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction', xytext = centerPt, textcoords = 'axes fraction',
                            va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)

#作用:绘制图像
#输入:
#输出:无
def createPlot(inTree):
    fig = plt.figure(1, facecolor = 'white')
    fig.clf()
    axprops = dict(xticks = [], yticks = [])
    createPlot.ax1 = plt.subplot(111, frameon = False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))#树的宽度
    plotTree.totalD = float(getTreeDepth(inTree))#数的高度
    plotTree.x0ff = -0.5/plotTree.totalW;#根节点x值?
    plotTree.y0ff = 1.0;#根节点y值,为1.0表示放在最高点
    plotTree(inTree, (0.5, 1.0), '')#绘制根节点,0.5表示在x方向的中间,1.0表示在y方向的最上面,''表示为根节点,不用标记子节点属性值
    #plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    #plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

#作用:获取叶节点的数目
#输入:树的字典形式
#输出:叶节点的数目
def getNumLeafs(myTree):
    numLeafs = 0#叶节点数目
    firstStr = myTree.keys()[0]#根节点键
    secondDict = myTree[firstStr]#根节点值
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

#作用:获取树的层数
#输入:树的字典形式
#输出:树的层数
def getTreeDepth(myTree):
    maxDepth = 0#数的层数
    firstStr = myTree.keys()[0]#根节点键
    secondDict = myTree[firstStr]#根节点值
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

#作用:输出树的字典形式
#输入:需要的数
#输出:树的字典形式
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

#作用:在父子节点间填充文本信息
#输入:子节点位置,父节点位置,文本信息
#输出:无
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)#叶节点数目
    depth = getTreeDepth(myTree)#树的层数
    firstStr = myTree.keys()[0]#根节点键
    cntrPt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 /plotTree.totalW, plotTree.y0ff)
    plotMidText(cntrPt, parentPt, nodeTxt)#绘制文字
    plotNode(firstStr, cntrPt, parentPt, decisionNode)#绘制根节点
    secondDict = myTree[firstStr]
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':#如果该子节点为字典类型,如是则递归调用
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)
            plotMidText((plotTree.x0ff, plotTree.y0ff), cntrPt, str(key))
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值