Implementing Decision Tree Classification in Python

Decision tree classification based on information gain is a fairly common classification method; the feature attributes are generally nominal (categorical) data.

The underlying theory is fairly simple, so no derivation is given here. Many of the programs available online are written for Python 2.x, so I am listing a Python 3.6 version for reference. Comments and discussion are welcome!
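For reference (these are the standard definitions, not spelled out in the original post), the two quantities the code below computes are the Shannon entropy of a dataset D and the information gain of splitting on an attribute A:

H(D) = -\sum_k p_k \log_2 p_k
Gain(D, A) = H(D) - \sum_v (|D_v| / |D|) H(D_v)

where p_k is the fraction of samples in class k, and D_v is the subset of D in which attribute A takes the value v.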

# Build the toy dataset: each row is [no surfacing, flippers, class label]
def create_dataset():
    data_set = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return data_set, labels
 
from math import log

# Compute the Shannon entropy of the dataset (the class label is the last column)
def cal_ent(dataset):
    num = len(dataset)
    label_counts = {}
    for vect in dataset:
        temp = vect[-1]
        if temp not in label_counts:
            label_counts[temp] = 0
        label_counts[temp] += 1
    shannon_entropy = 0
    for count in label_counts.values():
        prob = count / num  # fraction of samples in this class
        shannon_entropy -= prob * log(prob, 2)
    return shannon_entropy
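As a quick sanity check (the expected value is worked out by hand, not taken from the original post): the toy dataset has two 'yes' and three 'no' labels, so its entropy should be -(2/5)log2(2/5) - (3/5)log2(3/5) ≈ 0.971:

data_set, labels = create_dataset()
print(cal_ent(data_set))  # 0.9709505944546686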

 
from calculate_entropy import cal_ent
from split_dataset import split_dataset

# Select the attribute to split on, choosing the one with the largest information gain
def select_feature(dataset):
    num_features = len(dataset[0]) - 1  # the last column is the class label
    num_samples = len(dataset)
    base_entropy = cal_ent(dataset)
    best_gain = 0
    best_feature = 0
    for feature_position in range(num_features):
        feature_values = {item[feature_position] for item in dataset}
        conditional_entropy = 0
        for value in feature_values:
            subset = split_dataset(dataset, feature_position, value)
            conditional_entropy += len(subset) / num_samples * cal_ent(subset)
        gain = base_entropy - conditional_entropy
        if gain > best_gain:  # remember the position of the best feature so far
            best_gain = gain
            best_feature = feature_position
    return best_feature
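On the toy dataset this picks feature 0 ('no surfacing'): its gain is about 0.420, versus about 0.171 for 'flippers' (both figures worked out by hand, not from the original post):

print(select_feature(data_set))  # 0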

 
# Return the rows whose value at feature_position equals feature_value,
# with that feature column removed
def split_dataset(dataset, feature_position, feature_value):
    split_result = []
    for vect in dataset:
        if vect[feature_position] == feature_value:
            reduced_vect = vect[:feature_position]
            reduced_vect.extend(vect[feature_position+1:])
            split_result.append(reduced_vect)
    return split_result
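For example, splitting the toy dataset on feature 0 with value 1 keeps the three matching rows and drops that column (output checked by hand):

print(split_dataset(data_set, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]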

 
from majority import majority
from select_feature import select_feature
from split_dataset import split_dataset

# Recursively build the decision tree as nested dicts
def createtree(dataset, labels):
    classlist = [example[-1] for example in dataset]
    if classlist.count(classlist[0]) == len(classlist):
        return classlist[0]  # all samples share one class: return it as a leaf
    if len(dataset[0]) == 1:
        return majority(classlist)  # no features left: fall back to a majority vote
    bestfeature = select_feature(dataset)
    bestfeaturelabel = labels[bestfeature]
    mytree = {bestfeaturelabel: {}}
    del labels[bestfeature]  # note: this mutates the caller's label list
    uniquevals = {example[bestfeature] for example in dataset}
    for value in uniquevals:
        sublabels = labels[:]
        mytree[bestfeaturelabel][value] = createtree(split_dataset(dataset, bestfeature, value), sublabels)
    return mytree
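The majority helper imported above is never listed in the original post. A minimal sketch, assuming it simply returns the most frequent class label (ties broken arbitrarily), could look like this:

# Hypothetical stand-in for the missing majority module: majority vote over class labels
def majority(classlist):
    counts = {}
    for label in classlist:
        counts[label] = counts.get(label, 0) + 1
    return max(counts, key=counts.get)  # label with the highest count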
 

import matplotlib.pyplot as pyplot

# Visualize the generated tree
decision_node = dict(boxstyle="sawtooth", fc="0.8")
leafnode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

def getNumLeafs(myTree):
    numLeafs = 0  # initialize the leaf count
    # The next three lines are the Python 3 replacement for the two commented-out lines below
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]  # the first key is the label of the attribute split at this node
    secondDict = myTree[firstStr]
    #firstStr = myTree.keys()[0]
    #secondDict = myTree[firstStr]
    for key in secondDict.keys():  # test whether each child is a dict
        if type(secondDict[key]).__name__ == 'dict':  # a dict child is an internal node
            numLeafs += getNumLeafs(secondDict[key])  # recurse to count the subtree's leaves
        else:
            numLeafs += 1  # anything else is a leaf
    return numLeafs  # total number of leaf nodes in the tree

# Compute the depth of the tree (number of split levels along the deepest path)
def cal_treedepth(tree):
    maxdepth = 0
    first_key = list(tree.keys())
    key_name = first_key[0]
    seconddict = tree[key_name]
    for key in seconddict.keys():
        if type(seconddict[key]).__name__ == 'dict':
            this_depth = 1 + cal_treedepth(seconddict[key])  # internal node: recurse
        else:
            this_depth = 1  # leaf directly below this node
        if this_depth > maxdepth:
            maxdepth = this_depth
    return maxdepth

# Canned trees for testing the plotting code without rebuilding one
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]
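For the first canned tree the two helpers above give (checked by hand against the nested dict):

print(getNumLeafs(retrieveTree(0)))    # 3 leaves
print(cal_treedepth(retrieveTree(0)))  # depth 2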

# Place the edge label (the feature value) midway between parent and child
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=0)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # tree width, used together with plotTree.totalW
    depth = cal_treedepth(myTree)  # tree height, used together with plotTree.totalD
    # Python 3 change: dict.keys() can no longer be indexed directly
    firstSides = list(myTree.keys())  # firstStr = myTree.keys()[0]  # the text label for this node
    firstStr = firstSides[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)  # divide the x axis according to the leaf count
    plotMidText(cntrPt, parentPt, nodeTxt)  # annotate the incoming edge with the feature value
    plotNode(firstStr, cntrPt, parentPt, decision_node)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  # drawing proceeds top-down, so decrease y
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict child is an internal node, not a leaf
            plotTree(secondDict[key], cntrPt, str(key))  # recurse into the subtree
        else:  # leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  # advance to the next x position
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafnode)  # draw the leaf
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  # edge label
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  # restore y for the caller

def createPlot(inTree): 
    fig = pyplot.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = pyplot.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = pyplot.subplot(111, frameon=False)  # ticks for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(cal_treedepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    pyplot.show()
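
The plotting code can be exercised on its own with one of the canned trees:

createPlot(retrieveTree(0))  # draws the two-level 'no surfacing' / 'flippers' tree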

from create_dataset import create_dataset
from createtree import createtree
from tree_plot import *

# Main program: classify a test vector by walking the tree from the root to a leaf
def classify(tree, labels, vect):
    keylist = list(tree.keys())
    keyname = keylist[0]  # attribute tested at this node
    position_index = labels.index(keyname)  # which element of vect holds that attribute
    secondtree = tree[keyname]
    for key in secondtree.keys():
        if key == vect[position_index]:
            if type(secondtree[key]).__name__ == 'dict':
                return classify(secondtree[key], labels, vect)  # descend into the subtree
            else:
                return secondtree[key]  # reached a leaf: its value is the class label

data_set, labels = create_dataset()
vect = [1, 1]  # test sample: no surfacing = 1, flippers = 1
label_list = ['no surfacing', 'flippers']  # fresh copy, since createtree consumes its labels argument
my_tree = createtree(data_set, labels)
createPlot(my_tree)
print(my_tree)
print(classify(my_tree, label_list, vect))

The vect and label_list at the end form a test sample used to exercise the classification tree.

The output is:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
yes

