分类算法:ID3与C4.5及CART

原理

  • ID3算法的介绍网上有很多,它是通过选择能获得最大信息增益的属性来构建决策树。
  • C4.5是通过选择能获得最大信息增益率的属性来构建决策树。
  • CART用于观察值和输出值都是连续的值的情况,它可以通过选择则最优划分点来做分类;也可以通过将最优划分点改成线性函数(使每次划分时,点均匀分布在函数两侧)来做预测
  • 要理解信息熵,先要理解熵。熵:当能量均匀分布在物体中时,熵最高。当能量不均匀时,熵最小。带入信息-熵中讲,就是混杂的信息越多,即越混乱说明信息量很大,熵也很大;当信息被提纯后,越来越有序,则信息量很小(因为都是同类),熵也很小。
  • ID3中最大信息增益就是熵增最大,即使信息摆放更有序的方向。然而使信息摆放更有序的方向往往是选会择有最多取值的那个属性,比如性别和年龄而言更趋向于选有(青,中,老)三个取值的年龄属性。因此在C4.5算法中引入了信息增益率,即用ID3信息增益/信息在当前属性上的不确定度。下面会贴出公式,一看就懂。

信息熵的计算公式

I(s1,s2,...,sm)=i=1mpilog2pi

ID3中信息增益的计算公式

Gain(A)=I(s1,s2,...,sm)E(A)

E(A)=j=1vs1j+s2j+s3j+...+smjI(s1j,s2j,s3j,...,smj)

即(按属性A的取值分类前的信息熵)-(按属性A的取值分类后的各子集的信息熵的加权平均值)

C4.5中信息增益率的计算公式

GainRatio(S,A)=Gain(S,A)SplitInfo(S,A)

SplitInfo(S,A)=i=1c|si||S|log2(|si||S|)

SplitInfo(S,A)就是用来衡量属性A内部的混乱度,c是属性A的取值个数, si 是A按属性取值划分后的一个子集,||是求集合大小。


实例

  • 导入训练/测试数据,并作数据集的预处理
  • 构建决策树
    • loop1
    • 计算每个属性信息增益
    • 返回最优特征
    • 划分子集
    • 删除最优特征属性
    • 遍历每个子集
    • loop1
  • 持久化树模型
  • 从文件导入树模型
  • 决策

代码

#MyID3.py
#-*-coding:utf-8-*-
from numpy import *
import math
import copy
import cPickle as pickle
class ID3DTree(object):
    def __init__(self): #construct function
        self.tree={}    #built tree
        self.dataSet=[] #the dataSet
        self.labels=[]  #the classes

    def loadDataSet(self,path,labels):  #import data function
        recordlist=[]
        fp=open(path,"rb")
        content=fp.read()
        fp.close()
        rowlist=content.splitlines()
        recordlist=[row.split("\t") for row in rowlist if row.strip()]
        self.dataSet=recordlist
        self.labels=labels

    def train(self):    #run Decision Tree function
        labels=copy.deepcopy(self.labels)
        self.tree=self.buildTree(self.dataSet,labels)

    def buildTree(self,dataSet,labels): #building Decision Tree,the most important function
        cateList=[data[-1] for data in dataSet] #遍历每行取最后一列,默认抽取最后一列特征用来做类别的判断
        if cateList.count(cateList[0])==len(cateList):  #若子集类别只有一种,则返回该类别
            return cateList[0]
        if len(dataSet[0])==1:  #若子集只有一列,还需判断该列是否纯净,若不纯净则返回该子集中占比最大的类别
            return self.maxCate(cateList)
        #alogrithm core
        bestFeat=self.getBestFeat(dataSet)  #获取数据集的最优特征列的下标
        bestFeatLabel=labels[bestFeat]      #根据下标在labels里找到其对应的name
        tree={bestFeatLabel:{}} #bestFeatLabel作为根节点 {'root':{0:'leaf node',1:{'level2':{0:'leaf node',1:'leaf node'}},2:{'level2':{0:'leaf node',1:'leaf node'}}}}
        del(labels[bestFeat])
        #抽取最优特征列向量
        uniqueVals=set([data[bestFeat] for data in dataSet])    #去掉该特征列里的重复值
        for value in uniqueVals:
            subLabels=labels[:] #用删除上层特征列的特征集做子集的特征列集合
            splitDataSet=self.splitDataSet(dataSet,bestFeat,value)
            subTree=self.buildTree(splitDataSet,subLabels)  #递归构建子树
            tree[bestFeatLabel][value]=subTree  #(回溯二层)tree['年龄']['青']=(回溯一层)tree['学生']['是']=(递归到底)tree['买'] 深度优先遍历
        return tree

    def maxCate(self,cateList): #当最后只剩一列特征列但类别仍不纯净时
        items=dict([(cateList.count(i),i) for i in cateList])
        return items[max(items.keys())]

    def getBestFeat(self,dataSet):
        #计算特征向量维度,其中最后一列用于类别标签(买,不买),因此要减去
        numFeatures=len(dataSet[0])-1   #用于决策的特征数目
        baseEntropy=self.computeEntropy(dataSet)    #计算未细分时,当前层数据集的熵
        bestInfoGain=0.0    #初始化信息熵增益
        bestFeature=-1      #初始化最优特征列
        #遍历各特征列,计算信息熵,并计算信息熵增益
        for i in xrange(numFeatures):
            uniqueVals=set([data[i] for data in dataSet])   #保存特征列的值有几种
            newEntropy=0.0  #记录按特征列划分后的子数据集的信息熵
            for value in uniqueVals:
                subDataSet=self.splitDataSet(dataSet,i,value)
                prob=len(subDataSet)/float(len(dataSet))    #计算特征列里个value占当前dataSet的比率
                newEntropy+=prob*self.computeEntropy(subDataSet)    #遍历计算特征为i,值为value的子集的熵,最后加权平均
            infoGain=baseEntropy-newEntropy #保存按特征列i划分时的信息增益
            if(infoGain>bestInfoGain):  #记录下最优特征列
                bestInfoGain=infoGain
                bestFeature=i
        return bestFeature

    def getBestFeat4_5(self,dataSet): #c4.5算法,按信息增益率选取特征,避免信息增益选择特征时偏向于选取特征值个数较多的情况(特征值越多更利于混乱的减少)
        Num_Feats=len(dataSet[0][:-1])  #取dataSet第一行,下标从第一个到最后一个(不包括)的子数组求长度
        totality=len(dataSet)
        BaseEntropy=self.computeEntropy(dataSet)    #计算当前dataSet的熵
        for f in xrange(Num_Feats):
            featList=[feat[f] for feat in dataSet]  #遍历得到一列特征


    def computeEntropy(self,dataSet):   #计算信息混乱度,越混乱越高
        datalen=float(len(dataSet))
        cateList=[data[-1] for data in dataSet] #提取要分类的特征列(买,不买)
        items=dict([(i,cateList.count(i)) for i in cateList])
        infoEntropy=0.0
        for key in items:
            prob=float(items[key])/datalen
            infoEntropy-=prob*math.log(prob,2)
        return infoEntropy

    def splitDataSet(self,dataSet,axis,value):  #dataSet数据集,axis特征列下标,value特征列的取值之一
        rtnList=[]
        for featVec in dataSet:
            if(featVec[axis]==value):
                rFeatVec=featVec[:axis] #list操作,装入0~axis-1间的元素
                rFeatVec.extend(featVec[axis+1:])   #list操作,装入axis+1,length-1间的元素
                rtnList.append(rFeatVec)
        return rtnList

    #序列化函数
    def storeTree(self,inputTree,filename):
        fw=open(filename,'w')
        pickle.dump(inputTree,fw)
        fw.close()

    def grabTree(self,filename):
        fr=open(filename)
        return pickle.load(fr)

    #决策函数
    def predict(self,inputTree,featLabels,testVec):
        root=inputTree.keys()[0]    #得到树的根节点对应的特征名字
        secondDict=inputTree[root]  #得到树根对应各个取值下的子树或节点(节点即最终分类)
        featIndex=featLabels.index(root)    #得到特征名对应的label下标
        key=testVec[featIndex]  #得到测试集在该特征的取值
        valueOfFeat=secondDict[key] #根据key值取得子树或最终类别
        if isinstance(valueOfFeat,dict):    #判断valueOfFeat是不是字典类,即使不是子树
            classLabel=self.predict(valueOfFeat,featLabels,testVec) #递归分类,将子树作为树传下去
        else: classLabel=valueOfFeat
        return classLabel

#treePlotter.py
# -*- coding: utf-8 -*-
'''
Created on 2015年7月27日

@author: pcithhb
'''
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

#获取叶节点的数目
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#测试节点的数据是否为字典,以此判断是否为叶节点
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

#获取树的层数
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#测试节点的数据是否为字典,以此判断是否为叶节点
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

#绘制节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

#绘制连接线
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

#绘制树结构
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

#创建决策树图形
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()
#MyID3_Main.py
#-*-coding:utf-8 -*-
from numpy import *
from MyID3 import *
import treePlotter as tp
dtree=ID3DTree()
dtree.loadDataSet("C:\Users\MCGG\Documents\python\dataset.dat",["age","revenue","student","credit"])
dtree.train()
tp.createPlot(dtree.tree)
print dtree.tree

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值