原理介绍
简要介绍
决策树算法是一个分类算法(监督学习),通过训练会获得一颗树形的分类模型。树的每一个非叶子节点都是一个判断条件,叶子节点是一个分类。
举个例子:我打算出去玩,我会先看看明天天气怎么样,如果下雨,我就不出去玩了【叶子节点】,如果不下雨,那么我再看看温度怎么样,低于10度就不出去玩了【叶子节点】,10度及以上则出去玩【叶子节点】。
原理
那么怎么获得这棵树呢?
假设一组数据有n个特征,我们想要高效的进行判断,会希望根节点是最具有区分度的,即如果我能通过根节点就分类成功最好。如果无法通过根节点进行分类,那么能通过第二层节点分类成功也不错。以此类推,通过尽量少的节点进行分类。
如何获得最有区分度的根节点呢?其实就是找到一个最具区分度的特征。我们假设我们有一个函数,你把每个特征数据穿给它,它便可以计算出一个分数,这个分数越大,代表特征越有区分度。
函数逻辑比较复杂,放到后面单独说,假设我们已经有了这个函数,那么我首先通过第一轮计算,选出了n个特征中最具区分度的特征A。假设A可以被分为3类(三种类型的数据,比如A特征表示大小,数据中有:大、中、小三种),则我们将A最为根节点,并向下延伸产生3个分支。并且按照A的三个分类将数据分为三组,去掉A特征,计算剩余特征中最具区分度的特征。以此类推形成树。
是否继续向下划分的判断依据是:
- 是否某一特征下的label都一致,如果一致,则不在继续往下区分(因为再怎么分,label都不变了)。
- 没有更多的特征进行划分,则结束(看当前特征哪个label占比更高则结果为此label)
得分函数(信息熵)
这个得分函数常见的有三种:ID3,C4.5,CART,这里介绍其中一种:ID3.
ID3的核心是: 信息熵。何为信息熵呢?我们都学过熵增现象,比如你滴一滴墨水到一杯清水里,墨水会散开到整杯水,且不会自动聚拢。而信息熵所描绘的就是一个信息的确定性,类比一下,如果一滴墨汁滴到一滴水里,那么还是乌黑,我们很容易确定这是墨汁。但如果把一滴水滴到一个大水缸里,那你还能确定当前水缸里的是墨汁吗?所以:
- 信息熵越大,则信息的确定性越低。
- 信息熵越小,则信息的确定性越高。
再回到我们上面原理部分的描述,我们要的是确定性高的特征,即最好当前节点(特征)就能判别出来结果,所以我们要求信息熵尽可能的小,即当前特征的确定性越高越好。那怎么计算信息熵呢?很早之前伟大的香农博士就已经帮我们定义好了计算信息熵的方法:
H
(
x
)
=
−
∑
i
m
p
(
x
)
log
2
p
(
x
)
H(x)=-\sum_{i}^{m}p(x)\log_{2}{p(x)}
H(x)=−∑imp(x)log2p(x)
此处的p(x)表示概率。这里假设特征A有三类数据:a,b,c。
- a中有3条数据,1条数据对应分类1,2条数据对应分类二。
- b中有4条数据,2条数据对应分类1,2条数据对应分类二。
- c中有4条数据,1条数据对应分类1,3条数据对应分类二。
则有A的信息熵为:
3
11
∗
(
−
1
3
∗
log
2
1
3
−
2
3
∗
log
2
2
3
)
+
4
11
∗
(
−
2
4
∗
log
2
2
4
−
2
4
∗
log
2
2
4
)
+
4
11
∗
(
−
1
4
∗
log
2
1
4
−
3
4
∗
log
2
3
4
)
\frac{3}{11} * (-\frac{1}{3}*\log_{2}{\frac{1}{3}} -\frac{2}{3}*\log_{2}{\frac{2}{3}}) + \frac{4}{11} * (-\frac{2}{4}*\log_{2}{\frac{2}{4}} -\frac{2}{4}*\log_{2}{\frac{2}{4}}) + \frac{4}{11} * (-\frac{1}{4}*\log_{2}{\frac{1}{4}} -\frac{3}{4}*\log_{2}{\frac{3}{4}})
113∗(−31∗log231−32∗log232)+114∗(−42∗log242−42∗log242)+114∗(−41∗log241−43∗log243)
实战
数据集
鸢尾花数据集:https://www.kaggle.com/datasets/himanshunakrani/iris-dataset
数据处理
origin_data = pd.read_csv("/kaggle/input/iris-dataset/iris.csv")
origin_data.loc[:,'species'].value_counts()
data = origin_data.replace({'species':{'setosa':1,'versicolor':2,'virginica':3}})
data.head()
X = data.drop(columns=['species'])
y = data.loc[:,'species']
训练
from sklearn import tree
model = tree.DecisionTreeClassifier(criterion='entropy', min_samples_leaf=5)
model.fit(X,y)
- criterion=‘entropy’ : 使用信息熵方式计算(默认)
- min_samples_leaf=5:每次向下进行分类,样本数最少为5个
预测+评估
y_predict = model.predict(X)
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y,y_predict)
print(accuracy)
输出:0.9733333333333334
绘制决策树
from matplotlib import pyplot as plt
plt.figure(figsize=(10,10))
tree.plot_tree(model, filled=True, feature_names=['sepal_length','sepal_width','petal_length','petal_width'],class_names=['setosa','versicolor','virginica'])
- filled: 是否填充颜色
- feture_names: 按原列表顺序写
- class_names: 处理后的数字升序排列后对应的名称