【常见决策树算法逻辑理解以及代码实现(5)】CART (代码实现,包含绘图,西瓜书示例)

使用的面向对象方式编写,主要类是Cart类,直接传入数据和属性集合,然后draw就可以

运行结果如下(每次运行属性值顺序可能会不同,由于hash问题,不用管,结果是一样的)

全部代码可下载项目https://gitee.com/TomCoCo/mLearn.git

在这里插入图片描述

这里是代码,有完整的注释,可以直接运行如上图

核心方法 createTree

import math
import matplotlib.pyplot as plt
import copy

D = [
['青绿','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['乌黑','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
['乌黑','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['青绿','蜷缩','沉闷','清晰','凹陷','硬滑','是'],
['浅白','蜷缩','浊响','清晰','凹陷','硬滑','是'],
['青绿','稍蜷','浊响','清晰','稍凹','软粘','是'],
['乌黑','稍蜷','浊响','稍糊','稍凹','软粘','是'],
['乌黑','稍蜷','浊响','清晰','稍凹','硬滑','是'],
['乌黑','稍蜷','沉闷','稍糊','稍凹','硬滑','否'],
['青绿','硬挺','清脆','清晰','平坦','软粘','否'],
['浅白','硬挺','清脆','模糊','平坦','硬滑','否'],
['浅白','蜷缩','浊响','模糊','平坦','软粘','否'],
['青绿','稍蜷','浊响','稍糊','凹陷','硬滑','否'],
['浅白','稍蜷','沉闷','稍糊','凹陷','硬滑','否'],
['乌黑','稍蜷','浊响','清晰','稍凹','软粘','否'],
['浅白','蜷缩','浊响','模糊','平坦','硬滑','否'],
['青绿','蜷缩','沉闷','稍糊','稍凹','硬滑','否']
]
A = ['色泽','根蒂','敲声','纹理','脐部','触感','好瓜']

class Cart:
    # 数据
    data = None
    # 属性集合
    attributes = None
    # 属性集合 (与下标关系),去除最后的类型判定列
    attributesAndIndex = None
    # 属性下标 (与属性可能的取值),去除最后的类型判定列
    attributesIndexAndValue = None
    # 根节点
    root = None

    def __init__(self,data,attributes):
        self.data = data       
        self.attributes = attributes
        self.attributesAndIndex = Cart.getAttributesAndIndex(attributes)
        self.attributesIndexAndValue = Cart.getAttributesAndValue(data,attributes)
        
    def draw(self):
        self.createTree(self.root,self.data,self.attributesAndIndex,None)
        tree = Tree(self.root)
        tree.drawTree()

    # attributesAndIndex 不是类的那个属性了,这个引用会在递归的过程中长度被削减
    def createTree(self,node,data,attributesAndIndex,desc):
        # 创建节点
        newNode = Node()
        # 如果传入了desc,写入
        if(desc is not None):
            newNode.desc = desc
        if node is None:
            self.root = newNode
        else:
            node.addChild(newNode)

        # 如果data中的样本属于同一类别,那么将newNode标记为C类叶节点.返回
        kMap = Cart.getKMap(data)
        if len(kMap) == 1:
            newNode.name = next(iter(kMap.keys()))
            return

        # 如果属性列表是空集,或D在A上的取值相同
        if Cart.checkDA(data,attributesAndIndex):
            # 获取数据集中较多的那个类别
            newNode.name = Cart.getMoreType(data)
            return

        # 获取最优属性下标
        bestIndex = Cart.getMinGiniIndexStrict(data,attributesAndIndex)
        newNode.name = self.attributes[bestIndex]

        # 遍历最优属性的每一个属性值(从原始数据中)
        aStart = self.attributesIndexAndValue[bestIndex]
        # 按最优属性拆分数据,为多个子集
        V = Cart.splitDataByIndex(data,bestIndex)
        for aStartV in aStart:
            dv = V.get(aStartV)
            # 如果dv是空集,那么以获取数据集中较多的那个类别建立子节点
            if dv is None or len(dv) == 0:
                newLeaf = Node()
                newLeaf.name = Cart.getMoreType(data)
                newLeaf.desc = aStartV
                newNode.addChild(newLeaf)
            else:
                 # 将A抛去选中的那个
                Anew = copy.deepcopy(attributesAndIndex)
                for index,item in enumerate(Anew):
                    if next(iter(item.values())) == bestIndex:
                        Anew.pop(index)
                        break
                self.createTree(newNode,dv,Anew,aStartV)


    # 检查D在a上的取值是否完全相同(data的所有数据不一定类别完全相同,只要在a上(可能多个)的取值完全相同即可)
    # 也就是指定类型的那些属性值完全一致,例如下文中的根蒂,脐部.在data上均没有区别都是稍蜷和稍凹
    # 例如 A['根蒂','脐部'] D : 
    #['青绿','稍蜷','浊响','清晰','稍凹','软粘','是'],
    #['乌黑','稍蜷','浊响','稍糊','稍凹','软粘','是'],
    #['乌黑','稍蜷','浊响','清晰','稍凹','硬滑','是'],
    #['乌黑','稍蜷','沉闷','稍糊','稍凹','硬滑','否'],
    @staticmethod
    def checkDA(data,attributesAndIndex):
        if len(attributesAndIndex) == 0:
            return True
        for item in attributesAndIndex:
            # 当前的属性值
            nowAttributesValue = None
            # 获取属性下标 i
            aIndex = next(iter(item.values()))
            for dLine in data:
                if nowAttributesValue == None:
                    nowAttributesValue = dLine[aIndex]
                elif nowAttributesValue != dLine[aIndex]:
                    return False
        return True;


    # 将属性附加一个指向数据的哪一个列,删除最后的类别信息,只保留属性信息
    # ['色泽','根蒂'] -> [{'色泽':0},{'根蒂':1}]
    @staticmethod
    def getAttributesAndIndex(attributes):
        attributesAndIndex = list()
        for index,attribute in enumerate(attributes):
            attributesAndIndex.append({attribute:index})
        return attributesAndIndex[:len(attributesAndIndex) - 1]
    

    @staticmethod
    def getAttributesAndValue(data,attributes):
        attributesAndValue = dict()
        for dLine in data:
            for i in range(len(attributes) - 1):
                v = attributesAndValue.get(i)
                if v == None:
                    v = set()
                    attributesAndValue[i] = v
                v.add(dLine[i])
        return attributesAndValue

    # 获取data数据集中,基尼指数最小的那个属性的下标,
    # attributesAndIndex的不需要维度必须和data[]的维度一致.使用attributesAndIndex指定的下标查询.不忽略最后一个
    @staticmethod
    def getMinGiniIndexStrict(data,attributesAndIndex):
        minGiniIndex = None
        minIndex = None
        for item in attributesAndIndex:
            # 获取属性名 v ,属性下标 i
            aName = next(iter(item.keys()))
            aIndex = next(iter(item.values()))
            giniIndex = Cart.getGiniIndex(data,aIndex)
            if minGiniIndex == None or giniIndex < minGiniIndex:
                minGiniIndex = giniIndex
                minIndex = aIndex
            print("第" , aIndex ,"列的属性",aName,"的基尼指数为:" , giniIndex)
        print("第" , minIndex ,"列的属性",aName,"的基尼指数最小为:" , minGiniIndex ,";为最优划分属性")
        return minIndex

    # 获取data数据集中,基尼指数最小的那个属性的下标,attributes的维度必须和data[]的维度一致
    @staticmethod
    def getMinGiniIndex(data,attributes):
        # attributes 的最后一列是类别,不计入
        attributesSize = len(attributes) - 1
        i = 0
        minGiniIndex = None
        minIndex = None
        while i < attributesSize:
            giniIndex = Cart.getGiniIndex(data,i)
            if minGiniIndex == None or giniIndex < minGiniIndex:
                minGiniIndex = giniIndex
                minIndex = i
            print("第" , i ,"列的属性",attributes[i],"的基尼指数为:" , giniIndex)
            i += 1
        print("第" , minIndex ,"列的属性",attributes[minIndex],"的基尼指数最小为:" , minGiniIndex ,";为最优划分属性")
        return minIndex

    
    # 获取基尼指数 data的最后一列认为是类型
    @staticmethod
    def getGiniIndex(data,attributesIndex):
        # 首先按照属性下标(attributesIndex)拆分出多个子集,
        V = Cart.splitDataByIndex(data,attributesIndex)
        # 总数据大小
        dSize = len(data)
        # 计算每个子集的Gini值,加权求和
        rs = 0
        for Dv in V.values():
            dvSize = len(Dv)
            dvGini = Cart.getGini(Cart.getKMap(Dv),dvSize)
            rs += (dvSize/dSize) * dvGini
        return rs
    
    #按照属性下标(attributesIndex)拆分出多个子集,子集的集合为:V,每个子集为Dv
    @staticmethod
    def splitDataByIndex(data,attributesIndex):
        V = dict()
        for dLine in data:
            attribute = dLine[attributesIndex]
            Dv = V.get(attribute)
            if Dv is None:
                Dv = list()
                V[attribute] = Dv
            Dv.append(dLine)
        return V


    # 获取基尼值
    @staticmethod
    def getGini(kMap,dSize):
        rs = 0
        for item in kMap.values():
            pk = (item/dSize)
            rs += pk * pk
        return 1 - rs

    @staticmethod
    def getMoreType(data):
        kMap = Cart.getKMap(data)
        maxCount = -1
        maxName = None
        for key in kMap.keys():
            if kMap.get(key) > maxCount:
                maxCount = kMap.get(key)
                maxName = key
        return maxName


    # 获取指定集合种类型->数量的映射
    @staticmethod
    def getKMap(data):
        kMap = dict()
        for dLine in data:
            # 获取分类值k
            k = dLine[len(dLine) - 1]
            # 获取当前k出现的次数
            kNum = kMap.get(k)
            if  kNum is None:
                kMap[k] = 1
            else:
                kMap[k] = kNum + 1
        return kMap

############################### 节点类 #####################################
class Node:
    name = "未命名节点"
    # 线描述,没有的是根节点
    desc = ""
    # 子节点,长度为0的是叶节点
    children = []

    def __init__(self):
        self.children = []        

    def addChild(self, node):
        self.children.append(node)

############################### 画树类 #####################################
class Tree:
    root = None
    # 定义决策节点以及叶子节点属性:boxstyle表示文本框类型,sawtooth:锯齿形;circle圆圈,fc表示边框线粗细
    decisionNode = dict(boxstyle="round4", fc="0.5")
    leafNode = dict(boxstyle="circle", fc="0.5")
    # 定义箭头属性
    arrow_args = dict(arrowstyle="<-")
    # 步长,每个节点的横线和纵向距离
    step = 3

    # 当前深度
    deep = 0
    # 当前深度的个数
    nowDeepIndex = 0
    # 当前深度和这个深度的当前节点数量的映射
    deepIndex = dict()

    def __init__(self, root):
        self.root = root

        # 设定坐标范围
        plt.xlim(0, 20)
        plt.ylim(-18, 0)
        # 设定中文支持
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["axes.unicode_minus"] = False

    # 绘制叶节点
    # x1,y1 箭头起始点坐标
    # x2,y2 箭头目标点(文字点坐标)
    # text  节点文字
    # desc  线文字
    def drawLeaf(self, x1, y1, x2, y2, text, desc):
        # 绘制节点以及箭头
        plt.annotate(text,
                     xy=(x1, y1),
                     xytext=(x2, y2),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=self.leafNode,
                     arrowprops=self.arrow_args)
        # 绘制线上的文字
        plt.text((x1 + x2) / 2, (y1 + y2) / 2, desc)

    # 绘制决策节点
    def drawDecision(self, x1, y1, x2, y2, text, desc):
        # 绘制节点以及箭头
        plt.annotate(text,
                     xy=(x1, y1),
                     xytext=(x2, y2),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=self.decisionNode,
                     arrowprops=self.arrow_args)
        # 绘制线上的文字
        plt.text((x1 + x2) / 2, (y1 + y2) / 2, desc)

    # 绘制根节点(特殊决策节点)
    def drawRoot(self, text):
        # 绘制节点以及箭头
        plt.annotate(text,
                     xy=(0, 0),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=self.decisionNode)

    def drawTree(self):
        self.draw0(self.root, 0, 0)
        plt.show()

    # xy是父节点的坐标
    def draw0(self, node, x, y):
        # 如果当前深度节点数量没有,则置为0
        if(self.deepIndex.get(self.deep) is None):
            self.deepIndex[self.deep] = 0
        # 注意因为是基于当前节点数量排列所有节点,故都基于0点排列
        x2 = self.deepIndex[self.deep] * self.step
        y2 = y - self.step
        if len(node.children) > 0:
            if len(node.desc) > 0:
                self.drawDecision(x, y, x2, y2, node.name, node.desc)
                self.deep += 1
                for i, child in enumerate(node.children):
                    self.draw0(child, x2, y2)
                self.deep -= 1
            else:
                self.drawRoot(node.name)
                for i, child in enumerate(node.children):
                    self.draw0(child, 0, 0)
        else:
            self.drawLeaf(x, y, x2, y2, node.name, node.desc)
        # 当前深度节点数++
        self.deepIndex[self.deep] = self.deepIndex[self.deep] + 1


# 程序入口
cart = Cart(D,A)
cart.draw()


###########测试checkDA方法##########
# D2=[
# ['青绿','稍蜷','浊响','清晰','稍凹','软粘','是'],
# ['乌黑','稍蜷','浊响','稍糊','稍凹','软粘','是'],
# ['乌黑','稍蜷','浊响','清晰','稍凹','硬滑','是'],
# ['乌黑','稍蜷','沉闷','稍糊','稍凹','硬滑','否']]
# AA2 = [{'根蒂':1},{'脐部':4}]
# a = Cart.checkDA(D2,AA2)
# a2 = Cart.checkDA(D,[])
# a3 = Cart.checkDA(D,Cart.getAttributesAndIndex(A))
# print(a,a2,a3)


  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
好的,我会尽力为您解答。首先,决策树算法是一种基于树结构的分类算法,可以用于预测事物的分类或属性。而Titanic乘客存活率预测是一个经典的机器学习案例,我们可以通过优化代码来提高预测的准确性和效率。 以下是一些可以优化决策树算法实现Titanic乘客存活率预测的方法: 1. 数据预处理:在进行决策树算法之前,我们需要对数据进行预处理,包括缺失值处理、数据类型转换、特征选择等。可以使用pandas库来进行数据的读取和处理。 2. 特征工程:特征工程是指对原始数据进行特征提取和转换的过程,可以提高模型的预测准确性。可以使用sklearn库中的特征选择工具和特征转换工具来进行特征工程。 3. 模型选择和调参:决策树算法有多种实现方式,如ID3、C4.5、CART等。可以通过交叉验证和网格搜索等方法来选择最优的模型和参数。 4. 模型评估:在使用决策树算法进行预测时,需要对模型进行评估,包括准确率、精确率、召回率等指标。可以使用sklearn库中的评估工具来进行模型评估。 下面是一个决策树算法实现Titanic乘客存活率预测的示例代码: ```python import pandas as pd from sklearn.tree import DecisionTreeClassifier from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score # 读取数据 data = pd.read_csv('train.csv') # 数据预处理 data['Age'].fillna(data['Age'].mean(), inplace=True) data.drop(['Cabin', 'Name', 'Ticket'], axis=1, inplace=True) data['Embarked'].fillna('S', inplace=True) data['Sex'] = data['Sex'].map({'male': 0, 'female': 1}) data['Embarked'] = data['Embarked'].map({'S': 0, 'C': 1, 'Q': 2}) # 特征工程 X = data.drop(['Survived'], axis=1) y = data['Survived'] X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) # 模型训练 clf = DecisionTreeClassifier() clf.fit(X_train, y_train) # 模型预测 y_pred = clf.predict(X_test) print('Accuracy:', accuracy_score(y_test, y_pred)) ``` 以上代码实现了数据预处理、特征工程、模型训练和模型预测等步骤,通过使用sklearn库中的决策树算法进行预测,并计算了预测的准确率。需要注意的是,这只是一个简单的示例代码,实际优化还需要考虑更多的因素,如特征选择、模型调参、集成学习等。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值