决策树

在这里插入图片描述
信息增益
通常信息增益越大,则意味着使用属性a划分所获得的样本集合的综合纯度越大。

在这里插入图片描述
在这里插入图片描述

ID3决策树和CART决策树

from sklearn.datasets import load_iris
import numpy as np
import math
from collections import Counter


class decisionnode:
    def __init__(self, d=None, thre=None, results=None, NH=None, lb=None, rb=None, max_label=None):
        self.d = d   # d表示维度
        self.thre = thre  # thre表示二分时的比较值,将样本集分为2类
        self.results = results  # 最后的叶节点代表的类别
        self.NH = NH  # 存储各节点的样本量与经验熵的乘积,便于剪枝时使用
        self.lb = lb  # desision node,对应于样本在d维的数据小于thre时,树上相对于当前节点的子树上的节点
        self.rb = rb  # desision node,对应于样本在d维的数据大于thre时,树上相对于当前节点的子树上的节点
        self.max_label = max_label  # 记录当前节点包含的label中同类最多的label


def entropy(y):
    '''
    计算信息熵,y为labels
    '''

    if y.size > 1:

        category = list(set(y))
    else:

        category = [y.item()]
        y = [y.item()]

    ent = 0

    for label in category:
        p = len([label_ for label_ in y if label_ == label]) / len(y)
        ent += -p * math.log(p, 2)

    return ent


def Gini(y):
    '''
    计算基尼指数,y为labels
    '''
    category = list(set(y))
    gini = 1

    for label in category:
        p = len([label_ for label_ in y if label_ == label]) / len(y)
        gini += -p * p

    return gini


def GainEnt_max(X, y, d):
    '''
    计算选择属性attr的最大信息增益,X为样本集,y为label,d为一个维度,type为int
    '''
    ent_X = entropy(y)
    X_attr = X[:, d]
    X_attr = list(set(X_attr))
    X_attr = sorted(X_attr)
    Gain = 0
    thre = 0

    for i in range(len(X_attr) - 1):
        thre_temp = (X_attr[i] + X_attr[i + 1]) / 2
        y_small_index = [i_arg for i_arg in range(
            len(X[:, d])) if X[i_arg, d] <= thre_temp]
        y_big_index = [i_arg for i_arg in range(
            len(X[:, d])) if X[i_arg, d] > thre_temp]
        y_small = y[y_small_index]
        y_big = y[y_big_index]

        Gain_temp = ent_X - (len(y_small) / len(y)) * \
            entropy(y_small) - (len(y_big) / len(y)) * entropy(y_big)
        '''
        intrinsic_value = -(len(y_small) / len(y)) * math.log(len(y_small) /
                                                              len(y), 2) - (len(y_big) / len(y)) * math.log(len(y_big) / len(y), 2)
        Gain_temp = Gain_temp / intrinsic_value
        '''
        # print(Gain_temp)
        if Gain < Gain_temp:
            Gain = Gain_temp
            thre = thre_temp
    return Gain, thre


def Gini_index_min(X, y, d):
    '''
    计算选择属性attr的最小基尼指数,X为样本集,y为label,d为一个维度,type为int
    '''

    X = X.reshape(-1, len(X.T))
    X_attr = X[:, d]
    X_attr = list(set(X_attr))
    X_attr = sorted(X_attr)
    Gini_index = 1
    thre = 0

    for i in range(len(X_attr) - 1):
        thre_temp = (X_attr[i] + X_attr[i + 1]) / 2
        y_small_index = [i_arg for i_arg in range(
            len(X[:, d])) if X[i_arg, d] <= thre_temp]

        y_big_index = [i_arg for i_arg in range(
            len(X[:, d])) if X[i_arg, d] > thre_temp]
        y_small = y[y_small_index]
        y_big = y[y_big_index]

        Gini_index_temp = (len(y_small) / len(y)) * \
            Gini(y_small) + (len(y_big) / len(y)) * Gini(y_big)
        if Gini_index > Gini_index_temp:
            Gini_index = Gini_index_temp
            thre = thre_temp
    return Gini_index, thre


def attribute_based_on_GainEnt(X, y):
    '''
    基于信息增益选择最优属性,X为样本集,y为label
    '''
    D = np.arange(len(X[0]))
    Gain_max = 0
    thre_ = 0
    d_ = 0
    for d in D:
        Gain, thre = GainEnt_max(X, y, d)
        if Gain_max < Gain:
            Gain_max = Gain
            thre_ = thre
            d_ = d  # 维度标号

    return Gain_max, thre_, d_


def attribute_based_on_Giniindex(X, y):
    '''
    基于信息增益选择最优属性,X为样本集,y为label
    '''
    D = np.arange(len(X.T))
    Gini_Index_Min = 1
    thre_ = 0
    d_ = 0
    for d in D:
        Gini_index, thre = Gini_index_min(X, y, d)
        if Gini_Index_Min > Gini_index:
            Gini_Index_Min = Gini_index
            thre_ = thre
            d_ = d  # 维度标号

    return Gini_Index_Min, thre_, d_


def devide_group(X, y, thre, d):
    '''
    按照维度d下阈值为thre分为两类并返回
    '''
    X_in_d = X[:, d]
    x_small_index = [i_arg for i_arg in range(
        len(X[:, d])) if X[i_arg, d] <= thre]
    '''
    以上等价于
    x_small_index = []
    
    for i_arg in range(len(X[:, d])):
        if X[i_arg, d] <= thre:
            x_small_index.append(i_arg)
    '''
    x_big_index = [i_arg for i_arg in range(
        len(X[:, d])) if X[i_arg, d] > thre]

    X_small = X[x_small_index]
    y_small = y[x_small_index]
    X_big = X[x_big_index]
    y_big = y[x_big_index]
    return X_small, y_small, X_big, y_big


def NtHt(y):
    '''
    计算经验熵与样本数的乘积,用来剪枝,y为labels
    '''
    ent = entropy(y)
    print('ent={},y_len={},all={}'.format(ent, len(y), ent * len(y)))
    return ent * len(y)


def maxlabel(y):
    label_ = Counter(y).most_common(1)
    return label_[0][0]


def buildtree(X, y, method='Gini'):
    '''
    递归的方式构建决策树
    '''
    if y.size > 1:
        if method == 'Gini':
            Gain_max, thre, d = attribute_based_on_Giniindex(X, y)
        elif method == 'GainEnt':
            Gain_max, thre, d = attribute_based_on_GainEnt(X, y)
        if (Gain_max > 0 and method == 'GainEnt') or (Gain_max >= 0 and len(list(set(y))) > 1 and method == 'Gini'):
            X_small, y_small, X_big, y_big = devide_group(X, y, thre, d)
            left_branch = buildtree(X_small, y_small, method=method)
            right_branch = buildtree(X_big, y_big, method=method)
            nh = NtHt(y)
            max_label = maxlabel(y)
            return decisionnode(d=d, thre=thre, NH=nh, lb=left_branch, rb=right_branch, max_label=max_label)
        else:
            nh = NtHt(y)
            max_label = maxlabel(y)
            return decisionnode(results=y[0], NH=nh, max_label=max_label)
    else:
        nh = NtHt(y)
        max_label = maxlabel(y)
        return decisionnode(results=y.item(), NH=nh, max_label=max_label)


def printtree(tree, indent='-', dict_tree={}, direct='L'):
    # 是否是叶节点

    if tree.results != None:
        print(tree.results)

        dict_tree = {direct: str(tree.results)}

    else:
        # 打印判断条件
        print(str(tree.d) + ":" + str(tree.thre) + "? ")
        # 打印分支
        print(indent + "L->",)

        a = printtree(tree.lb, indent=indent + "-", direct='L')
        aa = a.copy()
        print(indent + "R->",)

        b = printtree(tree.rb, indent=indent + "-", direct='R')
        bb = b.copy()
        aa.update(bb)
        stri = str(tree.d) + ":" + str(tree.thre) + "?"
        if indent != '-':
            dict_tree = {direct: {stri: aa}}
        else:
            dict_tree = {stri: aa}

    return dict_tree


def classify(observation, tree):
    if tree.results != None:
        return tree.results
    else:
        v = observation[tree.d]
        branch = None

        if v > tree.thre:
            branch = tree.rb
        else:
            branch = tree.lb

        return classify(observation, branch)


def pruning(tree, alpha=0.1):
    if tree.lb.results == None:
        pruning(tree.lb, alpha)
    if tree.rb.results == None:
        pruning(tree.rb, alpha)

    if tree.lb.results != None and tree.rb.results != None:
        before_pruning = tree.lb.NH + tree.rb.NH + 2 * alpha
        after_pruning = tree.NH + alpha
        print('before_pruning={},after_pruning={}'.format(
            before_pruning, after_pruning))
        if after_pruning <= before_pruning:
            print('pruning--{}:{}?'.format(tree.d, tree.thre))
            tree.lb, tree.rb = None, None
            tree.results = tree.max_label


if __name__ == '__main__':
    iris = load_iris()
    X = iris.data
    y = iris.target

    #对X.shape[0]间的数随机排序
    permutation = np.random.permutation(X.shape[0])#X.shape[0]=150
    shuffled_dataset = X[permutation, :]
    shuffled_labels = y[permutation]
    #训练集乱序
    train_data = shuffled_dataset[:100, :]
    train_label = shuffled_labels[:100]

    test_data = shuffled_dataset[100:150, :]
    test_label = shuffled_labels[100:150]

    tree1 = buildtree(train_data, train_label, method='Gini')
    print('=============================')
    tree2 = buildtree(train_data, train_label, method='GainEnt')

    a = printtree(tree=tree1)
    b = printtree(tree=tree2)

    true_count = 0
    for i in range(len(test_label)):
        predict = classify(test_data[i], tree1)
        if predict == test_label[i]:
            true_count += 1
    print("CARTTree:{}".format(true_count))
    true_count = 0
    for i in range(len(test_label)):
        predict = classify(test_data[i], tree2)
        if predict == test_label[i]:
            true_count += 1
    print("C3Tree:{}".format(true_count))

    #print(attribute_based_on_Giniindex(X[49:51, :], y[49:51]))
    from pylab import *
    mpl.rcParams['font.sans-serif'] = ['SimHei']  # 指定默认字体
    mpl.rcParams['axes.unicode_minus'] = False  # 解决保存图像时负号'-'显示为方块的问题

    import treePlotter
    import matplotlib.pyplot as plt
    treePlotter.createPlot(a, 1)
    treePlotter.createPlot(b, 2)
    # 剪枝处理
    pruning(tree=tree1, alpha=4)
    pruning(tree=tree2, alpha=4)
    a = printtree(tree=tree1)
    b = printtree(tree=tree2)

    true_count = 0
    for i in range(len(test_label)):
        predict = classify(test_data[i], tree1)
        if predict == test_label[i]:
            true_count += 1
    print("CARTTree:{}".format(true_count))
    true_count = 0
    for i in range(len(test_label)):
        predict = classify(test_data[i], tree2)
        if predict == test_label[i]:
            true_count += 1
    print("C3Tree:{}".format(true_count))

    treePlotter.createPlot(a, 3)
    treePlotter.createPlot(b, 4)
    plt.show()

附上treePloter.py

可视化

import matplotlib.pyplot as plt

# 定义文本框和箭头格式
decisionNode = dict(boxstyle="round4", color='#3366FF')  # 定义判断结点形态
leafNode = dict(boxstyle="circle", color='#FF6633')  # 定义叶结点形态
arrow_args = dict(arrowstyle="<-", color='g')  # 定义箭头

# 绘制带箭头的注释
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# 计算叶结点数
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# 计算树的层数
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


# 在父子结点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center",
                        ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) /
              2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # 在父子结点间填充文本信息
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # 绘制带箭头的注释
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff,
                                       plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree, index=1):
    fig = plt.figure(index, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')


结果

ent=0.0,y_len=37,all=0.0
ent=0.0,y_len=1,all=0.0
ent=0.0,y_len=1,all=0.0
ent=1.0,y_len=2,all=2.0
ent=0.0,y_len=26,all=0.0
ent=0.22228483068568797,y_len=28,all=6.223975259199263
ent=0.0,y_len=1,all=0.0
ent=0.0,y_len=3,all=0.0
ent=0.8112781244591328,y_len=4,all=3.2451124978365313
ent=0.0,y_len=2,all=0.0
ent=1.0,y_len=6,all=6.0
ent=0.0,y_len=29,all=0.0
ent=0.4220005168831531,y_len=35,all=14.770018090910359
ent=0.9983636725938131,y_len=63,all=62.89691137341023
ent=1.579641206421168,y_len=100,all=157.9641206421168

ent=0.0,y_len=37,all=0.0
ent=0.0,y_len=1,all=0.0
ent=0.0,y_len=1,all=0.0
ent=1.0,y_len=2,all=2.0
ent=0.0,y_len=26,all=0.0
ent=0.22228483068568797,y_len=28,all=6.223975259199263
ent=0.0,y_len=1,all=0.0
ent=0.0,y_len=3,all=0.0
ent=0.8112781244591328,y_len=4,all=3.2451124978365313
ent=0.0,y_len=2,all=0.0
ent=1.0,y_len=6,all=6.0
ent=0.0,y_len=29,all=0.0
ent=0.4220005168831531,y_len=35,all=14.770018090910359
ent=0.9983636725938131,y_len=63,all=62.89691137341023
ent=1.579641206421168,y_len=100,all=157.9641206421168
2:2.6?
-L->
0
-R->
2:4.85?
–L->
0:5.05?
—L->
1:2.45?
----L->
1
----R->
2
—R->
1
–R->
3:1.75?
—L->
3:1.55?
----L->
2:4.95?
-----L->
1
-----R->
2
----R->
1
—R->
2
2:2.6?
-L->
0
-R->
2:4.85?
–L->
0:5.05?
—L->
1:2.45?
----L->
1
----R->
2
—R->
1
–R->
3:1.75?
—L->
3:1.55?
----L->
2:4.95?
-----L->
1
-----R->
2
----R->
1
—R->
2
CARTTree:47
C3Tree:47
before_pruning=8.0,after_pruning=6.0
pruning–1:2.45?
before_pruning=10.0,after_pruning=10.223975259199264
before_pruning=8.0,after_pruning=7.245112497836532
pruning–2:4.95?
before_pruning=11.245112497836532,after_pruning=10.0
pruning–3:1.55?
before_pruning=14.0,after_pruning=18.77001809091036
before_pruning=8.0,after_pruning=6.0
pruning–1:2.45?
before_pruning=10.0,after_pruning=10.223975259199264
before_pruning=8.0,after_pruning=7.245112497836532
pruning–2:4.95?
before_pruning=11.245112497836532,after_pruning=10.0
pruning–3:1.55?
before_pruning=14.0,after_pruning=18.77001809091036
2:2.6?
-L->
0
-R->
2:4.85?
–L->
0:5.05?
—L->
2
—R->
1
–R->
3:1.75?
—L->
1
—R->
2
2:2.6?
-L->
0
-R->
2:4.85?
–L->
0:5.05?
—L->
2
—R->
1
–R->
3:1.75?
—L->
1
—R->
2
CARTTree:45
C3Tree:45

在这里插入图片描述
在这里插入图片描述
决策树的剪枝策略
决策树的剪枝策略分为预剪枝和后剪枝

预剪枝
预剪枝就是边建立决策时边进行剪枝的操作。在决策树生成的过程中,对每个节点在划分前向首先进行估计,若当前节点的划分不能带来决策树泛化性能的提升,则停止划分并将当前节点标记为叶子节点。

预剪枝可以:限制树的深度,叶子节点个数,叶子节点的样本数,信息增益量等。

后剪枝
当建立完决策树后再进行剪枝操作。后剪枝是先从训练集生成一棵完整的决策树,然后自底向上地对非叶子节点进行考察,若将该节点对应的子树替换为叶子节点能够带来决策树泛化性能的提升,将该子树替换为叶子节点。

通过一定的衡量标准。这里讲的是CART算法的后剪枝方法——代价复杂度算法,即CCP算法。
在这里插入图片描述
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
该资源内项目源码是个人的课程设计、毕业设计,代码都测试ok,都是运行成功后才上传资源,答辩评审平均分达到96分,放心下载使用! ## 项目备注 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载学习,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可用于毕设、课设、作业等。 下载后请首先打开README.md文件(如有),仅供学习参考, 切勿用于商业用途。 该资源内项目源码是个人的课程设计,代码都测试ok,都是运行成功后才上传资源,答辩评审平均分达到96分,放心下载使用! ## 项目备注 1、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 2、本项目适合计算机相关专业(如计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载学习,也适合小白学习进阶,当然也可作为毕设项目、课程设计、作业、项目初期立项演示等。 3、如果基础还行,也可在此代码基础上进行修改,以实现其他功能,也可用于毕设、课设、作业等。 下载后请首先打开README.md文件(如有),仅供学习参考, 切勿用于商业用途。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值