[Python for ML] Decision Tree

tree.py

from math import log

def calEntropy(dataset):
    """Return the Shannon entropy, in bits, of the class labels in *dataset*.

    Each example in *dataset* is a sequence whose last element is its class
    label; the other elements are ignored here. An empty dataset yields 0.0.
    """
    total = len(dataset)

    # Tally how many examples carry each class label (first-seen order).
    counts = {}
    for example in dataset:
        cls = example[-1]
        counts[cls] = counts.get(cls, 0) + 1

    # H = -sum(p * log2(p)) over the labels that actually occur.
    entropy = 0.0
    for n in counts.values():
        p = float(n) / total
        entropy -= p * log(p, 2)
    return entropy

def creatData():
    """Return a small toy data set together with its feature names.

    Returns
    -------
    (dataset, labels)
        dataset : list of 5 rows ``[feature0, feature1, class]``, where both
                  features are 0/1 values and the class is "yes" or "no".
        labels  : the names of the two feature columns, in column order.

    New lists are built on every call, so callers may modify the result
    without affecting later calls.
    """
    dataset = [[1, 1, "yes"],
               [1, 1, "yes"],
               [1, 0, "no"],
               [0, 1, "no"],
               [0, 1, "no"]]
    # Feature names, one per feature column (the class column is unnamed).
    # Fixed typo: the first name was misspelled "no suifacing".
    label = ["no surfacing", "flippers"]
    return dataset, label


def splitData(dataset, axis, value):
    """Return the examples whose feature *axis* equals *value*, minus that feature.

    Parameters
    ----------
    dataset : list of rows (lists or tuples), one example per row.
    axis : index of the column to test and remove. Negative indices count
        from the end of each row, exactly as in ordinary Python indexing.
    value : the value that ``row[axis]`` must equal for the row to be kept.

    Returns
    -------
    A new list of new lists; the rows of *dataset* are not modified.
    Raises IndexError if *axis* is out of range for a row.
    """
    subset = []
    for row in dataset:
        if row[axis] == value:
            # Normalise a negative index first: with axis == -1 the naive
            # row[axis + 1:] is row[0:], which would re-append the whole row.
            idx = axis + len(row) if axis < 0 else axis
            # Copy the columns on either side of idx; list() also lets
            # tuple rows work (tuples have no .extend).
            reduced = list(row[:idx])
            reduced.extend(row[idx + 1:])
            subset.append(reduced)
    return subset

def chooseFeaturetoSplit(dataset):
    """Return the index of the feature with the largest information gain.

    The gain of feature i is calEntropy(dataset) minus the size-weighted
    average of calEntropy over the subsets that splitData produces for each
    distinct value of feature i. Only a strictly positive gain is accepted:
    if no feature beats 0, the result is -1. Among equal gains, the first
    feature wins.

    *dataset* must be non-empty (IndexError otherwise), and the last column
    of each row is the class label.
    """
    nRows = len(dataset)
    nFeatures = len(dataset[0]) - 1
    baseEntropy = calEntropy(dataset)

    winner, winnerGain = -1, 0
    for feat in range(nFeatures):
        # Expected entropy after splitting on this feature.
        weighted = 0.0
        for v in {row[feat] for row in dataset}:
            part = splitData(dataset, feat, v)
            weighted += len(part) / float(nRows) * calEntropy(part)
        gain = baseEntropy - weighted
        if gain > winnerGain:
            winner, winnerGain = feat, gain
    return winner

import operator
def majorCnt(classList):
    """Return the most frequent element of *classList* (a majority vote).

    Ties go to the label that appears first in *classList*: the tally dict
    keeps first-seen order, and sorted() is stable even with reverse=True.
    Raises IndexError if *classList* is empty.
    """
    tally = {}
    for cls in classList:
        tally[cls] = tally.get(cls, 0) + 1
    # Descending by count, which is element 1 of each (label, count) pair.
    ranked = sorted(tally.items(), key=operator.itemgetter(1), reverse=True)
    winner, _count = ranked[0]
    return winner


def createTree(dataset, labels):
    """Recursively build a decision tree as nested dicts.

    Parameters
    ----------
    dataset : non-empty list of rows. The last element of each row is the
        class label; the preceding elements are feature values.
    labels : names of the feature columns, in column order. This list is
        not modified; subtrees receive their own reduced copies.

    Returns
    -------
    Either a class label (a leaf), or ``{feature_name: {value: subtree}}``,
    where each subtree is again a label or such a dict.
    """
    classList = [example[-1] for example in dataset]
    # Leaf: every example in this node has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Leaf: only the class column is left, so use a majority vote.
    if len(dataset[0]) == 1:
        return majorCnt(classList)

    bestFeat = chooseFeaturetoSplit(dataset)
    if bestFeat < 0:
        # chooseFeaturetoSplit returns -1 when no feature has a strictly
        # positive gain (e.g. XOR-like data). Used as a column index, -1
        # would split on the class label itself. Fall back to the first
        # feature, as a plain argmax over equal (zero) gains would. Every
        # level removes one column, so the recursion still terminates.
        bestFeat = 0
    bestLabel = labels[bestFeat]

    # Remaining feature names, built without touching the caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    mytree = {bestLabel: {}}
    # One branch per distinct value of the chosen feature.
    for val in {example[bestFeat] for example in dataset}:
        mytree[bestLabel][val] = createTree(
            splitData(dataset, bestFeat, val), subLabels)
    return mytree

treePlotter.py

import matplotlib.pyplot as plt

# Box properties passed to annotate(bbox=...) for each kind of tree node.
## "sawtooth": box with a saw-toothed (jagged) border; "fc" is the face
## colour, here the grayscale string "0.8" (light gray)
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
## "round4": box with rounded edges; "w" is white
leafNode = {"boxstyle": "round4", "fc": "w"}
# Arrow properties passed to annotate(arrowprops=...)
arrow_args = {"arrowstyle": "<-"}

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one tree node as boxed text, connected by an arrow to its parent.

    nodeTxt  : text shown inside the node box.
    centerPt : (x, y) of this node, in "axes fraction" coordinates, where
               (0, 0) is the lower-left and (1, 1) the upper-right of the Axes.
    parentPt : (x, y) of the parent node, in the same coordinates; the
               arrow is drawn between parentPt and centerPt.
    nodeType : bbox property dict for the box (decisionNode or leafNode).

    Draws on ``createPlot.ax1``, so createPlot must have set that first.
    """
    axes = createPlot.ax1
    # annotate() places the text at xytext and draws an arrow between
    # xytext and xy; va/ha centre the text box on centerPt.
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords="axes fraction",
        xytext=centerPt,
        textcoords="axes fraction",
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )

def createPlot():
    """Demo: draw one decision node and one leaf node, then show the figure.

    The Axes is stored on the function object as ``createPlot.ax1`` so that
    plotNode can draw on it.
    """
    # Figure number 1 is reused if it already exists; clear anything on it.
    canvas = plt.figure(1, facecolor="white")
    canvas.clf()
    # A single Axes filling the figure; frameon=False skips drawing the
    # Axes frame/background patch.
    createPlot.ax1 = plt.subplot(111, frameon=False)

    # (text, node position, parent position, box style), drawn in order.
    demoNodes = (
        ("a decision node", (0.5, 0.1), (0.1, 0.5), decisionNode),
        ("a leaf node", (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for text, center, parent, style in demoNodes:
        plotNode(text, center, parent, style)

    plt.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值