Python ID3 Decision Tree Code

The code below is adapted, with some simplification, from the CSDN post "python实现ID3决策树分类算法" by aoanng (URL in the header comment below).


'''
function: ID3 decision tree construction
author: baomi
date: 2021/11/01
reference: https://blog.csdn.net/colourful_sky/article/details/82056125
'''

import math

def splitDataSet(dataSet, i, value):
    '''
    Return the subset of dataSet consisting of the instances whose i-th
    attribute equals value, with the i-th column removed.
    '''
    retDataSet = []
    for x in dataSet:
        if x[i] == value:
            temp = x[:]   # copy the row so the original dataSet is untouched
            temp.pop(i)   # drop the attribute we just split on
            retDataSet.append(temp)
    return retDataSet
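
For example, with the loan dataset defined later in this post, splitting on the third feature ('有房子') with value '是' keeps the six matching rows and removes that column:

# subset = splitDataSet(dataSet, 2, '是')
# len(subset)  -> 6
# subset[0]    -> ['青年', '是', '一般', '同意']   # the '有房子' column is gone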


def calcEntropy(dataSet):
    '''
    Compute the entropy of a data set (the class label is the last column).
    '''
    labelDict = {}  # maps each label to its count in dataSet
    for x in dataSet:
        label = x[-1]
        if label not in labelDict.keys():
            labelDict[label] = 0
        labelDict[label] += 1

    n = len(dataSet)
    retEntropy = 0.0
    for key in labelDict:
        p = float(labelDict[key]) / n  # empirical probability of this label
        retEntropy -= p * math.log(p, 2)

    return retEntropy
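
As a quick sanity check: the full loan dataset below has 9 '同意' and 6 '拒绝' labels, so

# calcEntropy(dataSet)
# = -(9/15) * log2(9/15) - (6/15) * log2(6/15)
# ≈ 0.971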


def calcInfoGain(dataSet, i):
    '''
    Compute the information gain obtained by splitting dataSet on the
    i-th feature.
    '''
    preEntropy = calcEntropy(dataSet)

    postEntropy = 0.0
    featureSet = set([x[i] for x in dataSet])  # all distinct values of feature i
    for feature in featureSet:  # weighted entropy of the subsets after splitting on feature i
        subDataSet = splitDataSet(dataSet, i, feature)
        subDataSetEntropy = calcEntropy(subDataSet)
        p = len(subDataSet) / len(dataSet)
        postEntropy += p * subDataSetEntropy

    return preEntropy - postEntropy
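
On the loan dataset below (the classic loan-approval example from Li Hang's Statistical Learning Methods), the four gains work out to approximately:

# calcInfoGain(dataSet, 0)   # 年龄     -> ≈ 0.083
# calcInfoGain(dataSet, 1)   # 有工作   -> ≈ 0.324
# calcInfoGain(dataSet, 2)   # 有房子   -> ≈ 0.420  (largest)
# calcInfoGain(dataSet, 3)   # 信贷情况 -> ≈ 0.363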

def getMaxInfoGainNode(dataSet, featureNameList):
    '''
    featureNameList holds the names of the feature columns of dataSet.
    Returns either:
    - a class label (str), when the entropy of dataSet is 0 or no
      features are left to split on; or
    - a tuple (index, name, gain) for the feature with the largest
      information gain.
    '''
    dataSetEntropy = calcEntropy(dataSet)
    if dataSetEntropy == 0:
        return dataSet[0][-1]  # entropy 0 means all labels are identical; return that label

    featureNum = len(featureNameList)
    if featureNum == 0:
        # No features left but the labels still disagree: fall back to the majority label
        labels = [x[-1] for x in dataSet]
        return max(set(labels), key=labels.count)

    maxInfoGain = 0
    maxInfoGainIndex = 0
    for i in range(0, featureNum):  # scan all features for the largest information gain
        infoGain = calcInfoGain(dataSet, i)
        if infoGain > maxInfoGain:
            maxInfoGain = infoGain
            maxInfoGainIndex = i

    return maxInfoGainIndex, featureNameList[maxInfoGainIndex], maxInfoGain
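
So on the full dataset the call should return the root split, roughly:

# getMaxInfoGainNode(dataSet, featureNameList)  -> (2, '有房子', 0.420...)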


def createID3Tree(dataSet, featureNameList):
    '''
    Return one tree node.
    If the entropy of dataSet is 0, return the (unique) class label of dataSet.
    Otherwise return a dict whose single key is the name of the best feature
    for dataSet; its value is another dict keyed by that feature's values,
    whose values are in turn subtrees (dicts) or labels, recursively.
    '''
    maxInfoGainNode = getMaxInfoGainNode(dataSet, featureNameList)
    if type(maxInfoGainNode) == str:
        return maxInfoGainNode

    nodeIndex, nodeName = maxInfoGainNode[0], maxInfoGainNode[1]
    ret = {}
    ret[nodeName] = {}

    featureSet = set([x[nodeIndex] for x in dataSet])
    for feature in featureSet:
        subDataSet = splitDataSet(dataSet, nodeIndex, feature)
        newFeatNameList = featureNameList[:]
        newFeatNameList.pop(nodeIndex)
        # recurse on the subset obtained by splitting on the best feature
        childTree = createID3Tree(subDataSet, newFeatNameList)
        ret[nodeName][feature] = childTree

    return ret

dataSet = [['青年', '否', '否', '一般', '拒绝'],
           ['青年', '否', '否', '好', '拒绝'],
           ['青年', '是', '否', '好', '同意'],
           ['青年', '是', '是', '一般', '同意'],
           ['青年', '否', '否', '一般', '拒绝'],
           ['中年', '否', '否', '一般', '拒绝'],
           ['中年', '否', '否', '好', '拒绝'],
           ['中年', '是', '是', '好', '同意'],
           ['中年', '否', '是', '非常好', '同意'],
           ['中年', '否', '是', '非常好', '同意'],
           ['老年', '否', '是', '非常好', '同意'],
           ['老年', '否', '是', '好', '同意'],
           ['老年', '是', '否', '好', '同意'],
           ['老年', '是', '否', '非常好', '同意'],
           ['老年', '否', '否', '一般', '拒绝'], ]

featureNameList = ['年龄', '有工作', '有房子', '信贷情况']


ID3Tree = createID3Tree(dataSet, featureNameList)
print(ID3Tree)

Output:

{'有房子': {'否': {'有工作': {'否': '拒绝', '是': '同意'}}, '是': '同意'}}
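
The nested dict can also be used for prediction. The post stops at printing the tree; below is a minimal lookup helper sketch (the name classify and its signature are my own, not from the original code):

def classify(tree, featureNameList, sample):
    # Descend through the nested dicts until a leaf (a plain label string) is reached
    if not isinstance(tree, dict):
        return tree
    featureName = list(tree.keys())[0]
    value = sample[featureNameList.index(featureName)]
    return classify(tree[featureName][value], featureNameList, sample)

# For example:
# classify(ID3Tree, featureNameList, ['青年', '是', '否', '一般'])  -> '同意'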

Plotting the result with matplotlib

The plotting code comes from the CSDN post "Matplotlib绘制树形图" by wancongconghao.

# Draw the tree with matplotlib

import matplotlib.pyplot as plt

decision_node = dict(boxstyle="sawtooth", fc="0.8")
leaf_node = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Count the leaves of the tree (this determines the width of the figure)
def get_leaf_num(tree):
    leaf_num = 0
    first_key = list(tree.keys())[0]
    next_dict = tree[first_key]
    for key in next_dict.keys():
        if type(next_dict[key]).__name__ == "dict":
            leaf_num += get_leaf_num(next_dict[key])
        else:
            leaf_num += 1
    return leaf_num

# Compute the depth of the tree (this determines the height of the figure)
def get_tree_depth(tree):
    depth = 0
    first_key = list(tree.keys())[0]
    next_dict = tree[first_key]
    for key in next_dict.keys():
        if type(next_dict[key]).__name__ == "dict":
            thisdepth = 1 + get_tree_depth(next_dict[key])
        else:
            thisdepth = 1
        if thisdepth > depth:
            depth = thisdepth
    return depth
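
For the tree plotted at the end of this post, these two helpers give, for example:

# tree = {'有房子': {'否': {'有工作': {'否': '拒绝', '是': '同意'}}, '是': '同意'}}
# get_leaf_num(tree)    -> 3  (the leaves '拒绝', '同意', '同意')
# get_tree_depth(tree)  -> 2  (two decision nodes on the longest root-to-leaf path)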

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

# Label the edge between a parent node and its child
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = get_leaf_num(myTree)
    depth = get_tree_depth(myTree)
    firstStr = list(myTree.keys())[0]
    # Center this node above its leaves; coordinates are axes fractions
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decision_node)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leaf_node)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(get_leaf_num(inTree))
    plotTree.totalD = float(get_tree_depth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

plt.rcParams['font.sans-serif'] = ['SimHei']  # display the Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False    # display minus signs correctly

tree = {'有房子': {'否': {'有工作': {'否': '拒绝', '是': '同意'}}, '是': '同意'}}
createPlot(tree)

Result: the decision tree rendered as a matplotlib figure, with '有房子' at the root, '有工作' below it, and '同意'/'拒绝' leaves.

