【机器学习】基础入门(2)决策树原理及代码实现

  • 了解决策树的理论知识
  • 掌握 决策树的 sklearn 函数调用使用并将其运用到企鹅数据集预测

1 算法原理

与上一部分的逻辑回归的相比,逻辑回归是将所有特征变换为概率后,通过大于某一概率阈值的划分为一类,小于某一概率阈值的为另一类;而决策树是对每一个特征做一个划分。
树形模型更加接近人的思维方式,可以产生可视化的分类规则,产生的模型具有可解释性(可以抽取规则)。

1.1 什么是决策树

决策树是一个类似于流程图的树形结构,树内部的每一个节点代表的是对一个特征的测试,树的分支代表该特征的每一个测试结果,而树的每一个叶子节点代表一个类别。如下所示。
在这里插入图片描述
决策树采用自顶向下的递归的方法,基本思想是以信息熵为度量构造一棵熵值下降最快的树,到叶子节点处熵值为0(叶节点中的实例都属于一类)。

1.3 学习过程

1.3.1 不纯度(GINI系数&Entropy熵)

决策树思想,实际上就是寻找最纯净的划分方法,这个最纯净在数学上叫纯度,纯度通俗点理解就是目标变量要分得足够开,节点内尽量都是“同类”的数据。
“不纯度” impurity来度量

  1. 不纯度(impurity)–GINI系数

在这里插入图片描述
一个简单的计算示例如下图:(GINI值越小,纯度越高)
在这里插入图片描述

  1. 不纯度(impurity)–Entropy熵
  • 信息熵是一种衡量数据混乱程度的指标,信息熵越小,则数据的“纯度”越高。
    设X是一个取有限个值的离散随机变量,其概率分布为:
    在这里插入图片描述
    则随机变量X的熵定义为 :
    在这里插入图片描述

1.3.2 特征选择

特征选择是指从训练数据中众多的特征中选择一个特征作为当前节点的分裂标准,如何选择特征有着很多不同量化评估标准标准,从而衍生出不同的决策树算法。

信息增益法:选择具有最高信息增益的特征作为测试特征,利用该特征对节点样本进行划分子集,会使得各子集中不同类别样本的混合程度最低,在各子集中对样本划分所需的信息(熵)最少。
(信息增益既可以用熵也可以用GINI系数来计算)

信息增益计算:
在这里插入图片描述
给出如下示例更好的理解计算过程:

在这里插入图片描述
在这里插入图片描述

1.3.3 建树

根据选择的特征评估标准,从上至下递归地生成子节点,直到数据集不可分则停止决策树停止生长。

  1. 根节点出发,根节点包括所有的训练样本。一个节点(包括根节点),若节点内所有样本均属于同一类别,那么将该节点就成为叶节点,并将该节点标记为样本个数最多的类别
  2. 采用信息增益法来选择用于对样本进行划分的特征,该特征即为测试特征,特征的每一个值都对应着从该节点产生的一个分支及被划分的一个子集。决策树中特征均为符号值,即离散值。若非连续值则要将其离散化。
  3. 递归上述划分子集及产生叶节点的过程,这样每一个子集都会产生一个决策(子)树,直到所有节点变成叶节点。
  4. 递归操作的停止条件:
    • 一个节点中所有的样本均为同一类别,即产生叶节点
    • 没有特征可以用来对该节点样本进行划分,此时也强制产生叶节点
    • 没有样本能满足剩余特征的取值,此时也强制产生叶节点

1.3.4 剪枝

决策树容易过拟合,理论上可以完全分得开数据。所以一般来需要剪枝,缩小树结构规模、缓解过拟合。剪枝技术有预剪枝和后剪枝两种。

  1. 预剪枝(Pre-Pruning):边建立决策树边进行剪枝 ---->更实用
    在决策树生成分支的过程中,设定一些规则来避免树过度生长:

    • 信息增益(率)小于阈值就不再分裂
    • 节点样本数小于阈值(例如1%)就不再分裂
    • 若分裂后叶节点样本数少于阈值(例如0.5%)就不再分裂
    • 树深度大于阈值(例如8层)就不再分裂
  2. 后剪枝(Post-Pruning):当建立完决策树后来进行剪枝
    根据每个分支的分类错误率及每个分支的权重,计算该节点不修剪时预期分类错误率。
    对于每个非叶节点,如果修剪后分类错误率变大,即放弃修剪;否则将该节点强制为叶节点,并标记类别。
    产生一系列修剪过的决策树候选之后,利用测试数据对各候选决策树的分类准确性进行评价,保留分类错误率最小的决策树。

1.4 决策树三种常用算法

1.4.1 ID3算法/基本决策树

ID3算法是最早提出的一种决策树算法,ID3算法的核心是在决策树各个节点上应用信息增益准则来选择特征,递归的构建决策树。

具体方法:
从根节点开始,对节点计算所有可能的特征的信息增益,选择信息增益最大的特征作为节点的特征,由该特征的不同取值建立子节点
再对子节点递归的调用以上方法,构建决策树
直到所有的特征信息增益均很小或没有特征可以选择为止。

问题:
在这里插入图片描述
选择ID作为特征,信息增益最大,可是这个特征意义不大,每个ID必然只对应一个类别。
故信息增益的问题的就从这里引发出来,它的缺点就是偏向选择取值较多的属性。

1.4.2 C4.5算法

首先介绍下什么是信息增益率
信息增益率(比):信息增益除以该属性本身的熵
在这里插入图片描述
利用信息增益率对多分叉进行惩罚,避免了ID3算法中的归纳偏置问题。

C4.5算法与ID3算法决策树的生成过程相似, 改用信息增益率(比)来选择特征。主要是改进了样本特征部分:

  1. 基本决策树要求特征A取值为离散值,对连续值可采用如将连续值按段进行划分,然后设置哑变量等方式。
  2. 特征A的每个取值都会产生一个分支,可能会因为划分出的子集样本量过小停止继续分支,强制标记类别后可能会产生局部错误。可采用A的一组取值作为分支条件;或采用二元决策树,每一个分支代表一个特征取值的情况(只有是否两种取值)。
  3. 某些样本在特征A上值缺失,可用其他样本中特征A出现最多的值来填充,或均值、中值等,有些也可用样本内部的平滑来补值,当样本量很大时也可丢弃缺失值样本。
  4. 数据集不断减小,子集样本量也越来越小,所构造出的决策树就可能出现碎片、重复、复制等问题,可以利用样本的原有特征构造新的特征进行建模。
  5. 信息增益法会倾向于选择取值比较多的特征(这是信息熵的定义决定了的)。增益比率法(gain ratio)将每个特征取值的概率考虑在内,及gini索引法,χ2χ2条件统计表法和G统计法等。

问题:
模型是用较为复杂的熵来度量,使用了相对较为复杂的多叉树,只能处理分类不能处理回归等

1.4.3 CART算法

既可以做分类,也可以做回归,只能形成二叉树
CART算法稍微复杂一些,马一篇写的很好的博文来学习。CART算法详解

1.5 决策树算法的参数

sklearn.tree.DecisionTreeClassifier(criterion=‘gini’, splitter=‘best’, max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, presort=False)

树模型参数:

  • 1.criterion: 决定模型特征选择的计算方法,gini or entropy
  • 2.splitter: best or random 前者是在所有特征中找最好的切分点 后者是在部分特征中(数据量大的时候)
  • 3.max_features:限制树的最大深度,超过设定深度的树枝全部剪掉,限制树深度能够有效地限制过拟合。None(所有),log2,sqrt,N ,特征小于50的时候一般使用所有的
  • 4.max_depth: 数据少或者特征少的时候可以不管这个值,如果模型样本量多,特征也多的情况下,可以尝试限制下**(预剪枝)**
  • 5.min_samples_split: 如果某节点的样本数少于min_samples_split,则不会继续再尝试选择最优特征来进行划分如果样本量不大,不需要管这个值。如果样本量数量级非常大,则推荐增大这个值。(预剪枝)
  • 6.min_samples_leaf: 这个值限制了叶子节点最少的样本数,如果某叶子节点数目小于样本数,则会和兄弟节点一起被剪枝,如果样本量不大,不需要管这个值,大些如10W可是尝试下5**(预剪枝)**
  • 7 min_weight_fraction_leaf:这个值限制了叶子节点所有样本权重和的最小值,如果小于这个值,则会和兄弟节点一起被剪枝默认是0,就是不考虑权重问题。一般来说,如果我们有较多样本有缺失值,或者分类树样本的分布类别偏差很大,就会引入样本权重,这时我们就要注意这个值了。
  • 8.max_leaf_nodes 通过限制最大叶子节点数,可以防止过拟合,默认是"None”,即不限制最大的叶子节点数。如果加了限制,算法会建立在最大叶子节点数内最优的决策树。如果特征不多,可以不考虑这个值,但是如果特征分成多的话,可以加以限制具体的值可以通过交叉验证得到。(预剪枝)
  • 9.class_weight:指定样本各类别的的权重,主要是为了防止训练集某些类别的样本过多导致训练的决策树过于偏向这些类别。这里可以自己指定各个样本的权重如果使用“balanced”,则算法会自己计算权重,样本量少的类别所对应的样本权重会高。
  • 10.min_impurity_split:这个值限制了决策树的增长,如果某节点的不纯度(基尼系数,信息增益,均方差,绝对差)小于这个阈值则该节点不再生成子节点。即为叶子节点 。(预剪枝)
  • n_estimators:要建立树的个数

1.6 决策树算法总结

  • 优点

    • 简单直观,生成的决策树很直观
       - 基本不需要预处理,不需要提前归一化,处理缺失值
       - 预测的代价是O(log2m)。 m为样本数
       - 既可以处理离散值也可以处理连续值。很多算法只是专注于离散值或者连续值
       - 可以处理多维度输出的分类问题
    • 相比于神经网络之类的黑盒分类模型,决策树在逻辑上可以得到很好的解释
    • 可以交叉验证的剪枝来选择模型,从而提高泛化能力
    • 对于异常点的容错能力好,健壮性高
  • 缺点

    • 决策树算法非常容易过拟合,导致泛化能力不强。可以通过设置节点最少样本数量和限制决策树深度来改进。
       - 决策树会因为样本发生一点点的改动(特别是在节点的末梢),导致树结构的剧烈改变。这个可以通过集成学习之类的方法解决。
       - 寻找最优的决策树是一个NP难的问题,我们一般是通过启发式方法,容易陷入局部最优。可以通过集成学习之类的方法来改善。
       - 有些比较复杂的关系,决策树很难学习,比如异或。这个就没有办法了,一般这种关系可以换神经网络分类方法来解决。
       - 如果某些特征的样本比例过大,生成决策树容易偏向于这些特征。这个可以通过调节样本权重来改善。

2 Demo实践

2.1 模型训练

# Demo演示DecisionTree分类

# 导入决策树模型函数
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

# 构造数据集
x_fearures = np.array([[-1, -2], [-2, -1], [-3, -2], [1, 3], [2, 1], [3, 2]])
y_label = np.array([0, 1, 0, 1, 0, 1])

# 用决策树模型拟合构造的数据集
tree_clf = DecisionTreeClassifier()
tree_clf = tree_clf.fit(x_fearures, y_label)

2.2 可视化

# 可视化构造的数据样本点
plt.scatter(x_fearures[:,0],x_fearures[:,1], c=y_label, cmap='viridis')
plt.title('Dataset')
# plt.show()

2.3 模型预测

# 模型预测
# 创建新样本
x_new1 = np.array([[0, -1]])
x_new2 = np.array([[2, 1]])

## 在训练集和测试集上分布利用训练好的模型进行预测
y_label_predict1 = tree_clf.predict(x_new1)
y_label_predict2 = tree_clf.predict(x_new2)

print('The New point 1 predict class:\n',y_label_predict1)
print('The New point 2 predict class:\n',y_label_predict2)
The New point 1 predict class:
 [1]
The New point 2 predict class:
 [0]

3 数据分析

选择企鹅数据(palmerpenguins)进行数据分析练习。该数据集一共包含8个变量,其中7个特征变量,1个目标分类变量,共有150个样本。目标变量为企鹅的三个亚属,分别是(Adélie, Chinstrap and Gentoo)。7个特征变量包含企鹅的七个特征,分别是所在岛屿,嘴巴长度,嘴巴深度,脚蹼长度,身体体积,性别以及年龄。

3.1 数据读取

data = pd.read_csv('penguins_raw.csv')
# 为了简化过程我们选取四个特征进行分析
data = data[['Species', 'Culmen Length (mm)', 'Culmen Depth (mm)', 'Flipper Length (mm)', 'Body Mass (g)']]
print(data.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 344 entries, 0 to 343
Data columns (total 5 columns):
Species                344 non-null object
Culmen Length (mm)     342 non-null float64
Culmen Depth (mm)      342 non-null float64
Flipper Length (mm)    342 non-null float64
Body Mass (g)          342 non-null float64
dtypes: float64(4), object(1)
memory usage: 13.5+ KB

3.2 缺失值处理

通过对数据观察可知,数据集中存在的缺失值为整行缺失,且仅有两条数据缺失,我们可选择将其直接删去,处理过后的数据如下

data = data.dropna()
print(data.info())
<class 'pandas.core.frame.DataFrame'>
Int64Index: 342 entries, 0 to 343
Data columns (total 5 columns):
Species                342 non-null object
Culmen Length (mm)     342 non-null float64
Culmen Depth (mm)      342 non-null float64
Flipper Length (mm)    342 non-null float64
Body Mass (g)          342 non-null float64
dtypes: float64(4), object(1)
memory usage: 16.0+ KB

3.3 信息查看

# 查看企鹅都有哪些类别
print("企鹅的种类有:\n",data['Species'].unique())

# 查看每一类别各有多少数量
print("每类企鹅数量分别为:\n", data['Species'].value_counts())

# 查看企鹅数据集的一些统计数据
print("统计数据:\n",data.describe())
企鹅的种类有:
 ['Adelie Penguin (Pygoscelis adeliae)' 'Gentoo penguin (Pygoscelis papua)'
 'Chinstrap penguin (Pygoscelis antarctica)']
 
每类企鹅数量分别为:
 Adelie Penguin (Pygoscelis adeliae)          151
Gentoo penguin (Pygoscelis papua)            123
Chinstrap penguin (Pygoscelis antarctica)     68
Name: Species, dtype: int64

统计数据:
        Culmen Length (mm)  Culmen Depth (mm)  Flipper Length (mm)    Body Mass (g)  
count          342.000000         342.000000           342.000000   			342.000000  
mean            43.921930          17.151170           200.915205   			4201.754386  
std          	    5.459584          	 1.974793            14.061714   			801.954536  
min             32.100000          13.100000           172.000000   			2700.000000  
25%             39.225000          15.600000           190.000000   			3550.000000  
50%             44.450000          17.300000           197.000000      		4050.000000  
75%             48.500000          18.700000           213.000000   		 	4750.000000  
max             59.600000          21.500000           231.000000      		6300.000000  
# 为了方便我们将企鹅类别转化为数字0,1,2
data['Species_num'] = pd.factorize(data['Species'])[0]

3.4 可视化描述

# 特征与标签组合的散点可视化
sns.pairplot(data=data, diag_kind='hist', hue= 'Species')
plt.show()

在这里插入图片描述

4 建模预测

4.1 二分类预测任务

4.1.1 构建决策树模型

from sklearn.model_selection import train_test_split
# 选择其类别为0和1的样本 (不包括类别为2的样本)
data_features = data[data['Species_num'].isin([0,1])][['Culmen Length (mm)','Culmen Depth (mm)',
            'Flipper Length (mm)','Body Mass (g)']]
data_target = data[data['Species_num'].isin([0,1])]['Species_num']
# 训练集与测试集8/2分
x_train, x_test, y_train, y_test = train_test_split(data_features, data_target, test_size = 0.2, random_state = 2020)

# 构建决策树模型
from sklearn.tree import DecisionTreeClassifier
dtree = DecisionTreeClassifier(criterion='entropy')

# 在训练集上训练决策树模型
dtree.fit(x_train, y_train)

train_test_split():是交叉验证中常用函数,从样本中随机的按比例选取train data和test data
X_train,X_test, y_train, y_test =train_test_split(train_data,train_target,test_size,random_state=0)

  • train_data:所要划分的样本特征集
  • train_target:所要划分的样本结果
  • test_size:样本占比,如果是整数的话就是样本的数量
  • random_state:是随机数的种子。

4.1.2 模型预测

# 利用模型进行预测
train_predict = dtree.predict(x_train)
test_predict = dtree.predict(x_test)

from sklearn import metrics
# 利用accuracy(准确度)评估模型效果
print('The accuracy of the DecisionTree Model is:',metrics.accuracy_score(y_train,train_predict))
print('The accuracy of the DecisionTree Model is:',metrics.accuracy_score(y_test,test_predict))
# 利用混淆矩阵 (预测值和真实值的各类情况统计矩阵)
confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
print('The confusion matrix result:\n',confusion_matrix_result)
The accuracy of the DecisionTree Model is: 1.0
The accuracy of the DecisionTree Model is: 0.9818181818181818

The confusion matrix result:
 [[29  1]
 [ 0 25]]

第一组预测用的x_train,即训练集中特征,所得预测结果正确率为百分百
第二组为真实预测,可得预测效果也比较好,模型就还挺好~

accuracy_score所有分类正确的百分比

  • accuracy_score(y_true, y_pred, normalize=True, sample_weight=None)
  • normalize:默认值为True,返回正确分类的比例;如果为False,返回正确分类的样本数

4.1.3 结果可视化

# 利用热力图对于结果进行可视化
sns.heatmap(confusion_matrix_result,annot=True)
# annot:默认False;True表示在热力图每个方格写入数据
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()

sns.heatmap()其他参数可参考这篇博文,总结的很全面。
在这里插入图片描述
从热力图中能明显看出所有样本都预测准确了,准确度为1
优秀~

4.2 三分类(多分类)预测任务

三分类(多分类)模型大致上与二分类同理

4.2.1 构建决策树模型

# 同理8/2分
x_train, x_test, y_train, y_test = train_test_split(data[['Culmen Length (mm)','Culmen Depth (mm)', 'Flipper Length (mm)','Body Mass (g)']], data[['Species']], test_size = 0.2, random_state = 2020)

# 训练决策树模型
dtree3 = DecisionTreeClassifier()
dtree3.fit(x_train, y_train)

4.2.2 模型预测

# 模型预测
train_predict = dtree3.predict(x_train)
test_predict = dtree3.predict(x_test)
# 利用 predict_proba() 函数预测其概率
train_predict_proba = dtree3.predict_proba(x_train)
test_predict_proba = dtree3.predict_proba(x_test)
print('The test predict Probability of each class:\n',test_predict_proba)

# 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
print('The accuracy of the DecisionTree Model is:',metrics.accuracy_score(y_train,train_predict))
print('The accuracy of the DecisionTree Model is:',metrics.accuracy_score(y_test,te
The test predict Probability of each class:  # 每一列分别代表一类的预测概率
 [[0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 … … …
 [0. 0. 1.]
 [0. 0. 1.]
 [1. 0. 0.]
 [1. 0. 0.]]
 
The accuracy of the DecisionTree Model is: 1.0
The accuracy of the DecisionTree Model is: 0.9565217391304348

4.2.3 结果可视化

# 查看混淆矩阵
confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
print('The confusion matrix result:\n',confusion_matrix_result)

# 利用热力图对于结果进行可视化
sns.heatmap(confusion_matrix_result, annot=True)
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()

在这里插入图片描述

2020.08.22
TBC…

参与评论 您还未登录,请先 登录 后发表或查看评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
©️2022 CSDN 皮肤主题:深蓝海洋 设计师:CSDN官方博客 返回首页

打赏作者

baekii

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值