【ML】决策树(Decision tree)原理 + 实践 (基于sklearn)

16 篇文章 0 订阅
12 篇文章 0 订阅

原理介绍

简要介绍

决策树算法是一个分类算法(监督学习),通过训练会获得一颗树形的分类模型。树的每一个非叶子节点都是一个判断条件,叶子节点是一个分类。
举个例子:我打算出去玩,我会先看看明天天气怎么样,如果下雨,我就不出去玩了【叶子节点】,如果不下雨,那么我再看看温度怎么样,低于10度就不出去玩了【叶子节点】,10度及以上则出去玩【叶子节点】。

原理

那么怎么获得这棵树呢?
假设一组数据有n个特征,我们想要高效的进行判断,会希望根节点是最具有区分度的,即如果我能通过根节点就分类成功最好。如果无法通过根节点进行分类,那么能通过第二层节点分类成功也不错。以此类推,通过尽量少的节点进行分类。

如何获得最有区分度的根节点呢?其实就是找到一个最具区分度的特征。我们假设我们有一个函数,你把每个特征数据穿给它,它便可以计算出一个分数,这个分数越大,代表特征越有区分度。

函数逻辑比较复杂,放到后面单独说,假设我们已经有了这个函数,那么我首先通过第一轮计算,选出了n个特征中最具区分度的特征A。假设A可以被分为3类(三种类型的数据,比如A特征表示大小,数据中有:大、中、小三种),则我们将A最为根节点,并向下延伸产生3个分支。并且按照A的三个分类将数据分为三组,去掉A特征,计算剩余特征中最具区分度的特征。以此类推形成树。
是否继续向下划分的判断依据是:

  1. 是否某一特征下的label都一致,如果一致,则不在继续往下区分(因为再怎么分,label都不变了)。
  2. 没有更多的特征进行划分,则结束(看当前特征哪个label占比更高则结果为此label)

得分函数(信息熵)

这个得分函数常见的有三种:ID3,C4.5,CART,这里介绍其中一种:ID3.

ID3的核心是: 信息熵。何为信息熵呢?我们都学过熵增现象,比如你滴一滴墨水到一杯清水里,墨水会散开到整杯水,且不会自动聚拢。而信息熵所描绘的就是一个信息的确定性,类比一下,如果一滴墨汁滴到一滴水里,那么还是乌黑,我们很容易确定这是墨汁。但如果把一滴水滴到一个大水缸里,那你还能确定当前水缸里的是墨汁吗?所以:

  1. 信息熵越大,则信息的确定性越低。
  2. 信息熵越小,则信息的确定性越高。

再回到我们上面原理部分的描述,我们要的是确定性高的特征,即最好当前节点(特征)就能判别出来结果,所以我们要求信息熵尽可能的小,即当前特征的确定性越高越好。那怎么计算信息熵呢?很早之前伟大的香农博士就已经帮我们定义好了计算信息熵的方法:
H ( x ) = − ∑ i m p ( x ) log ⁡ 2 p ( x ) H(x)=-\sum_{i}^{m}p(x)\log_{2}{p(x)} H(x)=imp(x)log2p(x)
此处的p(x)表示概率。这里假设特征A有三类数据:a,b,c。

  1. a中有3条数据,1条数据对应分类1,2条数据对应分类二。
  2. b中有4条数据,2条数据对应分类1,2条数据对应分类二。
  3. c中有4条数据,1条数据对应分类1,3条数据对应分类二。

则有A的信息熵为:
3 11 ∗ ( − 1 3 ∗ log ⁡ 2 1 3 − 2 3 ∗ log ⁡ 2 2 3 ) + 4 11 ∗ ( − 2 4 ∗ log ⁡ 2 2 4 − 2 4 ∗ log ⁡ 2 2 4 ) + 4 11 ∗ ( − 1 4 ∗ log ⁡ 2 1 4 − 3 4 ∗ log ⁡ 2 3 4 ) \frac{3}{11} * (-\frac{1}{3}*\log_{2}{\frac{1}{3}} -\frac{2}{3}*\log_{2}{\frac{2}{3}}) + \frac{4}{11} * (-\frac{2}{4}*\log_{2}{\frac{2}{4}} -\frac{2}{4}*\log_{2}{\frac{2}{4}}) + \frac{4}{11} * (-\frac{1}{4}*\log_{2}{\frac{1}{4}} -\frac{3}{4}*\log_{2}{\frac{3}{4}}) 113(31log23132log232)+114(42log24242log242)+114(41log24143log243)

实战

数据集

鸢尾花数据集:https://www.kaggle.com/datasets/himanshunakrani/iris-dataset

数据处理

origin_data = pd.read_csv("/kaggle/input/iris-dataset/iris.csv")
origin_data.loc[:,'species'].value_counts()
data = origin_data.replace({'species':{'setosa':1,'versicolor':2,'virginica':3}})
data.head()
X = data.drop(columns=['species'])
y = data.loc[:,'species']

训练

from sklearn import tree
model = tree.DecisionTreeClassifier(criterion='entropy', min_samples_leaf=5)
model.fit(X,y)
  • criterion=‘entropy’ : 使用信息熵方式计算(默认)
  • min_samples_leaf=5:每次向下进行分类,样本数最少为5个

预测+评估

y_predict = model.predict(X)

from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y,y_predict)
print(accuracy)

输出:0.9733333333333334

绘制决策树

from matplotlib import pyplot as plt
plt.figure(figsize=(10,10))
tree.plot_tree(model, filled=True, feature_names=['sepal_length','sepal_width','petal_length','petal_width'],class_names=['setosa','versicolor','virginica'])
  • filled: 是否填充颜色
  • feture_names: 按原列表顺序写
  • class_names: 处理后的数字升序排列后对应的名称
    在这里插入图片描述
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值