【常见决策树算法逻辑理解以及代码实现(3)】ID3 (代码实现,包含绘图,西瓜书示例)

import math
import matplotlib.pyplot as plt

D = [
['青绿','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['乌黑','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
['乌黑','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['青绿','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
['浅白','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['青绿','稍蜷','浊响','清晰','稍凹','软粘','是'],
['乌黑','稍蜷','浊响','稍糊','稍凹','软粘','是'],
['乌黑','稍蜷','浊响','清晰','稍凹','硬滑','是'],
['乌黑','稍蜷','沉闷','稍糊','稍凹','硬滑','否'],
['青绿','硬挺','清脆','清晰','平坦','软粘','否'],
['浅白','硬挺','清脆','模糊','平坦','硬滑','否'],
['浅白','蜷缩','浊响','模糊','平坦','软粘','否'],
['青绿','稍蜷','浊响','稍糊','凹陷','硬滑','否'],
['浅白','稍蜷','沉闷','稍糊','凹陷','硬滑','否'],
['乌黑','稍蜷','浊响','清晰','稍凹','软粘','否'],
['浅白','蜷缩','浊响','模糊','平坦','硬滑','否'],
['青绿','蜷缩','沉闷','稍糊','稍凹','硬滑','否']
]
A = ['色泽','根蒂','敲声','纹理','脐部','触感','好瓜']

############################### 信息熵相关算法#####################################

# 当前样本集合D中第k类样本所占比例为pk(k=1,2,3,…,|y|)
# 计算A的信息熵,以数据最后一列为分类
def getEnt(D):
    # 获取一个类型k->出现次数的map
    kMap = dict()
    for dLine in D:
        # 获取分类值k
        k = dLine[len(dLine) - 1]
        # 获取当前k出现的次数
        kNum = kMap.get(k)
        if  kNum is None:
            kMap[k] = 1
        else:
            kMap[k] = kNum + 1
    # 遍历map
    dLen = len(D)
    rs = 0
    for kk in kMap:
        pk = kMap[kk]/dLen
        rs = rs + pk * math.log2(pk)
    return -rs

# 求信息增益,aIndex为属性列号
def getGain(D,aIndex):
    dMap = dict()
    for dLine in D:
        # 获取属性
        k = dLine[aIndex]
        # 属性所属的数组
        dChildren = dMap.get(k)
        if  dChildren is None:
            dChildren = []
            dMap[k] = dChildren
        dChildren.append(dLine)
    rs = 0    
    for key in dMap:
        dChildren = dMap[key]
        entx = getEnt(dChildren)
        r = len(dChildren)/len(D) * entx
        rs = rs + r
    return getEnt(D) - rs

# 求信息增益最大的属性列号
def getMaxtGainIndex(D):
    i = 0
    nowMaxIndex = 0
    nowMaxGain = 0
    while i < len(D[0]) - 1:
        gainI = getGain(D,i)
        print("第:" ,i , "列Gain为:" , gainI)
        if gainI > nowMaxGain:
            nowMaxGain = gainI
            nowMaxIndex = i
        i += 1
    return nowMaxIndex

############################### 辅助算法 #####################################

# 判断D的集合是否是判定同一类型,即全是好瓜或全是坏瓜,返回判定结果以及好坏(为False是第二个参数无效)
def sameCategory(D): 
    allFlag = True
    nowJudge = None
    for d in D:
        # 取最后一列为为好坏瓜
        if nowJudge is None:
            nowJudge = d[len(d) -1]
        else:
            # 只要有一个不等,就继续拆分
            if nowJudge != d[len(d) -1]:
                allFlag = False
                break
    return allFlag,nowJudge

aAStartVMap = dict()
def initAStartV(D):
    if len(aAStartVMap) == 0:
        for dLine in D:
          for index,lable in enumerate(dLine):
            aStart = aAStartVMap.get(index)
            if aStart == None:
                aStart = set()
                aAStartVMap[index] = aStart
            aStart.add(lable)

# 获取指定的D,某一个的每一个属性值的集合AStartV
def getAStartV(index):
    return aAStartVMap[index]

############################### 节点类 #####################################
class Node:
    name = "未命名节点"
    # 线描述,没有的是根节点
    desc = ""
    # 子节点,长度为0的是叶节点
    children = []

        
    def __init__(self):
        self.children = []        

    def addChild(self, node):
        self.children.append(node)

############################### 画树类 #####################################
class Tree:
    root = None
    # 定义决策节点以及叶子节点属性:boxstyle表示文本框类型,sawtooth:锯齿形;circle圆圈,fc表示边框线粗细
    decisionNode = dict(boxstyle="round4", fc="0.5")
    leafNode = dict(boxstyle="circle", fc="0.5")
    # 定义箭头属性
    arrow_args = dict(arrowstyle="<-")
    # 步长,每个节点的横线和纵向距离
    step = 3

    # 当前深度
    deep = 0
    # 当前深度的个数
    nowDeepIndex = 0

    def __init__(self, root):
        self.root = root

        # 设定坐标范围
        plt.xlim(0, 20)
        plt.ylim(-18, 0)
        # 设定中文支持
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["axes.unicode_minus"] = False

    # 绘制叶节点
    # x1,y1 箭头起始点坐标
    # x2,y2 箭头目标点(文字点坐标)
    # text  节点文字
    # desc  线文字
    def drawLeaf(self, x1, y1, x2, y2, text, desc):
        # 绘制节点以及箭头
        plt.annotate(text,
                     xy=(x1, y1),
                     xytext=(x2, y2),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=self.leafNode,
                     arrowprops=self.arrow_args)
        # 绘制线上的文字
        plt.text((x1 + x2) / 2, (y1 + y2) / 2, desc)

    # 绘制决策节点
    def drawDecision(self, x1, y1, x2, y2, text, desc):
        # 绘制节点以及箭头
        plt.annotate(text,
                     xy=(x1, y1),
                     xytext=(x2, y2),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=self.decisionNode,
                     arrowprops=self.arrow_args)
        # 绘制线上的文字
        plt.text((x1 + x2) / 2, (y1 + y2) / 2, desc)

    # 绘制根节点(特殊决策节点)
    def drawRoot(self, text):
        # 绘制节点以及箭头
        plt.annotate(text,
                     xy=(0, 0),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=self.decisionNode)

    def drawTree(self):
        self.draw0(self.root, 0, 0)
        plt.show()

    # 当前深度和这个深度的当前节点数量的映射
    deepIndex = dict()
    # 当前深度
    deep = 0

    # xy是父节点的坐标
    def draw0(self, node, x, y):
        # 如果当前深度节点数量没有,则置为0
        if(self.deepIndex.get(self.deep) is None):
            self.deepIndex[self.deep] = 0
        # 注意因为是基于当前节点数量排列所有节点,故都基于0点排列
        x2 = self.deepIndex[self.deep] * self.step
        y2 = y - self.step
        if len(node.children) > 0:
            if len(node.desc) > 0:
                self.drawDecision(x, y, x2, y2, node.name, node.desc)
                self.deep += 1
                for i, child in enumerate(node.children):
                    self.draw0(child, x2, y2)
                self.deep -= 1
            else:
                self.drawRoot(node.name)
                for i, child in enumerate(node.children):
                    self.draw0(child, 0, 0)
        else:
            self.drawLeaf(x, y, x2, y2, node.name, node.desc)
        # 当前深度节点数++
        self.deepIndex[self.deep] = self.deepIndex[self.deep] + 1


############################### 建立决策树 #####################################
def createTree(D,node,aStart):
    # 创建节点
    newNode = Node()
    if(aStart is not None):
        newNode.desc = aStart
    if node is None:
        node = newNode
    else:
        node.addChild(newNode)

    # 如果这个子集是空集.那么标记为叶节点,返回(参考色泽-浅白)
    if len(D) == 0:
        #  先直接写成是吧,有点问题
         newNode.name = "是"
         return

    #如果D这个子集中,所有的判定都是好瓜或者是坏瓜,没有必要继续下去了,直接设定为叶节点
    allFlag,nowJudge = sameCategory(D)
    # 判断完了,全等,则直接建立为叶节点返回
    if allFlag:
        newNode.name = nowJudge
        return

    # 获取信息增益最高的列index,创建节点,按照这个属性拆分数据为D1,D2,D3...Dn
    index = getMaxtGainIndex(D)
    print("信息增益最高的列index:" , index, "newNode name:",A[index])
    newNode.name = A[index]

    # 一个属性->这个属性的子集的map,将原来的D按照属性拆分为几个子集,这个map的key就是下层个节点的desc
    aStartVMap = dict()


    # 不能直接以D的结果集取找所有属性,会导致属性丢失(此例中会在色泽中丢失)浅白
    for dLine in D:
        dv = aStartVMap.get(dLine[index])
        if dv is None:
            dv = []
            aStartVMap[dLine[index]] = dv
        dv.append(dLine)

    # 获取所有属性,然后比对一下有没有缺的,不上
    allAStartV = getAStartV(index)
    for aa in allAStartV:
        r = aStartVMap.get(aa)
        if r is None:
            aStartVMap[aa] = []

    # 先获取所有的属性,然后以属性遍历
    
    for aStart in aStartVMap:
        createTree(aStartVMap[aStart],newNode,aStart)
    return node
    

initAStartV(D)
root = createTree(D,None,None)
treex = Tree(root)
treex.drawTree()

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值