机器学习之决策树
决策树简介
学习目标:
-
理解决策树算法的基本思想
-
知道构建决策树的步骤
决策树是什么?
决策树是一种树形结构,树中每个内部节点表示一个特征上的判断,每个分支代表一个判断结果的输出,每个叶子节点代表一种分类结果
决策树是一种常用的机器学习算法和数据挖掘工具,是一种监督学习算法,英文是Decision Tree
决策树思想的来源非常朴素,试想每个人的大脑都有类似于if-else这样的逻辑判断,这其中的if表示的是条件,if之后的else就是一种选择或决策。程序设计中的条件分支结构就是if-else结构,最早的决策树就是利用这类结构分割数据的一种分类学习方法。
例子:
其主要作用体现在以下几个方面
-
分类(Classification):
-
这是决策树最核心的作用之一。它可以将数据集中的样本划分到预定义的类别中。
-
例如:根据用户的年龄、收入、职业等信息,判断该用户是否会购买某款产品(是/否);根据病人的症状和检查结果,诊断其患有的疾病类型。
-
-
回归(Regression):
-
决策树不仅可以用于分类,还可以用于预测连续数值型的目标变量。
-
例如:根据房屋的面积、位置、房龄等信息,预测房屋的价格;根据历史销售数据预测未来的销售额。
-
-
特征选择与重要性评估:
-
决策树在构建过程中会根据信息增益、基尼不纯度等指标选择最优的特征进行分裂。这使得它能够自动识别出对预测目标最有影响力的特征。
-
作用:帮助理解数据,识别关键因素,可用于后续模型的特征工程。
-
-
规则提取与知识发现:
-
决策树的结构直观,从根节点到叶节点的每一条路径都可以解释为一条“如果...那么...”的规则。
-
作用:便于非专业人士理解模型的决策逻辑,有助于从数据中发现潜在的规律和知识,具有很强的可解释性。
-
-
处理混合类型数据:
-
决策树能够自然地处理数值型和类别型特征,无需像某些算法那样进行复杂的预处理(如必须将所有特征转换为数值)。
-
-
处理非线性关系:
-
决策树通过分段的方式对特征空间进行划分,能够捕捉特征与目标变量之间的非线性关系,无需事先假设数据的分布。
-
-
数据探索与预处理:
-
通过观察决策树的分裂过程,可以发现数据中的异常值、缺失值模式以及特征之间的相互作用,为数据清洗和预处理提供指导。
-
总结来说,决策树的主要作用是提供一个直观、易于理解的模型来进行分类和回归预测,同时具备强大的可解释性和自动特征选择能力,是数据分析和机器学习中非常实用的工具。
决策树的建立过程:
1.特征选择:选取有较强分类能力的特征。
2.决策树生成:根据选择的特征生成决策树。
3.决策树也易过拟合,采用剪枝的方法缓解过拟合。
决策树的种类有什么?
有三种 : ID3树,C4.5树,CART树(CART分类树和CART回归树)
ID3树
目标:
1.理解信息熵的意义
2.理解信息增益的作用
3.知道ID3树的构建流程
ID3树是基于信息增益构建的决策树,如果特征的信息增益最大,那么该特征就作为节点
信息熵定义:熵在信息论中代表随机变量不确定度的度量,熵越大,数据的不确定性越高,熵越小,数据库的不确定性月底
公式:
举例子:
信息增益:特征a对训练数据集D的信息增益g(D,a)定义为集合的熵H(D)与特征a给定条件下D的熵H(D|a)之差
条件熵公式:
D^V/D是指在特征列中该样本的占有率(占比)
例子:
先看“是否出门”的分布:
-
总样本数:5
-
出门(是):3 次(1 晴天 + 2 阴天)
-
不出门(否):2 次(1 晴天 + 1 雨天)
所以:
-
P(是)=35=0.6P(是)=53=0.6
-
P(否)=25=0.4P(否)=52=0.4
信息熵公式:
H=−P(是)log2P(是)−P(否)log2P(否)H=−P(是)log2P(是)−P(否)log2P(否)
查一下(或计算):
-
log20.6≈−0.737log2(0.6)≈−0.737
-
log20.4≈−1.322log2(0.4)≈−1.322
代入:
H=−0.6×(−0.737)−0.4×(−1.322)=0.6×0.737+0.4×1.322≈0.442+0.529=0.971H=−0.6×(−0.737)−0.4×(−1.322)=0.6×0.737+0.4×1.322≈0.442+0.529=0.971
原始熵 H(S) ≈ 0.971 bit
第二步:按“天气”划分,计算每个子集的熵
我们将数据按“天气”分为三组:晴天、阴天、雨天。
-
晴天组(2 个样本)
-
是:1,否:1
-
P(是)=0.5,P(否)=0.5P(是)=0.5,P(否)=0.5
H晴=−0.5log2(0.5)−0.5log2(0.5)=0.5+0.5=1.0H晴=−0.5log2(0.5)−0.5log2(0.5)=0.5+0.5=1.0
样本数占比:25=0.452=0.4
-
阴天组(2 个样本)
-
是:2,否:0
-
P(是)=1,P(否)=0P(是)=1,P(否)=0
H阴=−1log21−0=0H阴=−1log21−0=0
样本数占比:25=0.452=0.4
-
雨天组(1 个样本)
-
否:1
-
P(否)=1P(否)=1
H雨=−1log21=0H雨=−1log21=0
样本数占比:15=0.251=0.2
第三步:计算条件熵
H天气=(晴占比)×H晴+(阴占比)×H阴+(雨占比)×H雨H天气=(晴占比)×H晴+(阴占比)×H阴+(雨占比)×H雨
=0.4×1.0+0.4×0+0.2×0=0.4+0+0=0.4=0.4×1.0+0.4×0+0.2×0=0.4+0+0=0.4
条件熵熵 = 0.4
第四步:计算信息增益
*信息增益 = 划分前的熵 - 条件熵
IG=H(S)−H天气=0.971−0.4=0.571I**G=H(S)−H天气=0.971−0.4=0.571
信息增益 ≈ 0.571 bit
结果解释
-
原始数据很不确定(熵 ≈ 0.971)
-
用“天气”来划分后,整体混乱程度大幅下降(加权熵降到 0.4)
-
所以,“天气”这个特征带来了 0.571 的信息增益
信息增益越大,说明这个特征越有用!
在这个例子中,0.571 的增益是很大的,说明“天气”是一个非常好的判断“是否出门”的特征。
ID3树构建流程
构建流程:
-
计算每个特征的信息增益
-
使用信息增益最大的特征将数据集 S 拆分为子集
-
使用该特征(信息增益最大的特征)作为决策树的一个节点
-
使用剩余特征对子集重复上述(1,2,3)过程
C4.5树
学习目标:理解信息增益率的意义,知道C4.5树的构建方法
C4.5 决策树的出现,是为了改进早期决策树算法(如 ID3)的缺陷,让它更强大、更实用,能处理真实世界中复杂多样的数据。
特点:
1.缓解了ID3分支过程中总喜欢偏向选择值较多的属性
2.可处理连续数值型属性,也增加了对缺失值的处理方法
3.只适合于能够驻留于内存的数据集,大数据集无能为力
C4.5树的关键信息是信息增益率,如果特征的信息增益率越大,那么这个特征就作为节点
信息增益率:
通俗的讲就是 信息增益率=信息增益/特征熵
-
Gain_Ratio 表示信息增益率
-
IV 表示分裂信息、内在信息(特征熵)
-
特征的信息增益 ➗ 内在信息
-
如果某个特征的特征值种类较多,则其内在信息值就越大。即:特征值种类越多,除以的系数就越大。
-
如果某个特征的特征值种类较小,则其内在信息值就越小。即:特征值种类越小,除以的系数就越小。
-
信息增益比本质: 是在信息增益的基础之上乘上一个惩罚参数。特征个数较多时,惩罚参数较小;特征个数较少时,惩罚参数较大。惩罚参数:数据集D以特征A作为随机变量的熵的倒数。
例子:
信息增益率本质:
1.特征的信息增益/特征的内在信息
2.相当于对信息增益进行修正,增加一个惩罚系数
3.特征取值个数较多时,惩罚系数较小,特征取值个数较少时,惩罚系数较大
4.惩罚系数:数据集D以特征a作为随机变量的熵的导数
CART决策树(分类)
Cart模型是一种决策树模型,它即可以用于分类,也可以用于回归。
分类和回归树模型采用不同的最优化策略。Cart回归树使用平方误差最小化策略,Cart分类生成树采用的基尼指数最小化策略。
基尼值:
基尼指数:
基尼指数值越小(cart),则说明优先选择该特征为节点。
例子:
场景:根据“是否下雨”决定“是否带伞”
我们有 6 条记录:
是否下雨 | 是否带伞 |
---|---|
是 | 是 |
是 | 是 |
是 | 否 |
否 | 否 |
否 | 否 |
否 | 是 |
我们想用“是否下雨”这个属性来划分数据,看看它的基尼指数是多少。
第一步:计算整体基尼指数(不分组)
总数据:6 条
-
带伞(是):3 条
-
不带伞(否):3 条
概率:
-
p是=3/6=0.5p是=3/6=0.5
-
p否=3/6=0.5p否=3/6=0.5
基尼指数:
第二步:按“是否下雨”分组,计算加权基尼指数
-
下雨时(3 条数据)
-
带伞:2 条
-
不带伞:1 条
概率:
-
p是=2/3p是=2/3, p否=1/3p否=1/3
基尼:
-
不下雨时(3 条数据)
-
带伞:1 条
-
不带伞:2 条
概率:
-
p是=1/3p是=1/3, p否=2/3p否=2/3
基尼:
第三步:计算加权基尼指数
每组都是 3 条,总 6 条,所以权重都是 3/6=0.5
基尼指数=0.5×0.444+0.5×0.444=0.444
''是否下雨”的基尼指数 ≈ 0.444
当CART决策树分类中的特征出现了连续的值
例子:
找到基尼指数最小的中点作为进行分类,如果跟其他特征的基尼指数相比也是最小的,那么就可以作为节点
API介绍
class sklearn.tree.DecisionTreeClassifier(criterion=’gini’, max_depth=None,random_state=None)
CART决策树(回归)
CART 回归树和 CART 分类树的不同之处在于:
-
CART 分类树预测输出的是一个离散值,CART 回归树预测输出的是一个连续值。
-
CART 分类树使用基尼指数作为划分、构建树的依据,CART 回归树使用平方损失。
-
分类树使用叶子节点里出现更多次数的类别作为预测类别,回归树则采用叶子节点里均值作为预测输出
CART回归树的平方损失
CART回归树的构建过程:
1 选择一个特征,将该特征的值进行排序,取相邻点计算均值作为待划分点
2 根据所有划分点,将数据集分成两部分:R1、R2
3 R1 和 R2 两部分的平方损失相加作为该切分点平方损失
4 取最小的平方损失的划分点,作为当前特征的划分点
5 以此计算其他特征的最优划分点、以及该划分点对应的损失值
6 在所有的特征的划分点中,选择出最小平方损失的划分点,作为当前树的分裂点
决策树剪枝
为什么要剪枝?
决策树剪枝是一种防止决策树过拟合的一种正则化方法;提高其泛化能力。
剪枝:把子树的节点全部删掉,使用用叶子节点来替换
剪枝方法:
1.预剪枝:指在决策树生成过程中,对每个节点在划分前先进行估计,若当前节点的划分不能带来决策树泛化性能提升,则停止划分并将当前节点标记为叶节点;
优点:预剪枝使决策树地很多分支没有展开,不单降低了过拟合风险,还显著减少了决策树地训练,测试时间开销
缺点:有些分支地当前划分虽不能提高泛化性能,但后续划分却有可能导致性能地显著提高,预剪枝决策树也带来了欠拟合地风险
2.后剪枝:是先从训练集生成一棵完整的决策树,然后自底向上地对非叶节点进行考察,若将该节点对应的子树替换为叶节点能带来决策树泛化性能提升,则将该子树替换为叶节点。
优点:比预剪枝保留了更多的分支,一般情况下,后剪枝决策树地欠拟合风险很小,泛化性能往往优于预剪枝
缺点:后剪枝先生成,后剪枝,自底向上地对树中所有非子叶节点进行逐一考察,训练时间开销比为未剪枝地决策树和预剪枝地决策树都要大得多
案例:泰坦尼克号生存案例
泰坦尼克号沉没是历史上最著名的沉船事件。1912年4月15日,在她的处女航中,泰坦尼克号在与冰山相撞后沉没,在2224名乘客和船员中造成1502人死亡。这场耸人听闻的悲剧震惊了国际社会,并为船舶制定了更好的安全规定。 造成海难失事的原因之一是乘客和船员没有足够的救生艇。尽管幸存下来有一些运气因素,但有些人比其他人更容易生存,例如妇女,儿童和社会地位较高的人群。 在这个案例中,我们要求您完成对哪些人可能存活的分析。
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from pandas.core.common import random_state
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, \
classification_report
from sklearn.tree import plot_tree
from sklearn.tree import DecisionTreeClassifier
# 1读取数据
data = pd.read_csv('../data/train.csv')
# 2数据预处理
# data.info()
data = pd.get_dummies(data,columns=['Sex'])
data.drop(['Sex_male'],axis=1,inplace=True)
# data.dropna(axis=1,how='all',inplace=True)
print(data.head())
x=data[['Sex_female','Age','SibSp','Parch','Fare']]
x=x.fillna(x.mean())
# x = x.interpolate(method='linear', axis=0)
y=data['Survived']
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.2,random_state=520)
# 3.特征工程,标准化
transfer = StandardScaler()
x_train=transfer.fit_transform(x_train)
x_test=transfer.transform(x_test)
# 4模型训练
# es=LogisticRegression()
# 默认使用gini值,那么就是CART树
es = DecisionTreeClassifier()
es.fit(x_train,y_train)
# 5.模型预测
y_predict=es.predict(x_test)
# 6模型评估
print(f"准确率:{es.score(x_test,y_test)}")
print(f"分类评估报告:{classification_report(y_test,y_predict,target_names=['Died','Survivor'])}")
# print(f"精确率:{precision_score(y_test,y_predict)}")
# print(f"召回率:{recall_score(y_test,y_predict)}")
# print(f"F1值:{f1_score(y_test,y_predict)}")
y_proba= es.predict_proba(x_test)[:,1]
# print(f"roc曲线:{roc_auc_score(y_test,y_proba)}")
# 6.决策树可视化
plt.figure(figsize=(50,30))
# 参数1:决策树类分期,参2:是否填充颜色,参3:最大深度,参4:特征名称,参5:类别名称
plot_tree(es,filled=True,max_depth=10,feature_names=['Sex_female','Age','SibSp','Parch','Fare'],class_names=['died','survived'])
plt.savefig('../data/dec_tree.png')
plt.show()
"""
PassengerId Survived Pclass ... Cabin Embarked Sex_female
0 1 0 3 ... NaN S False
1 2 1 1 ... C85 C True
2 3 1 3 ... NaN S True
3 4 1 1 ... C123 S True
4 5 0 3 ... NaN S False
[5 rows x 12 columns]
准确率:0.7932960893854749
分类评估报告: precision recall f1-score support
Died 0.75 0.90 0.82 93
Survivor 0.87 0.67 0.76 86
accuracy 0.79 179
macro avg 0.81 0.79 0.79 179
weighted avg 0.81 0.79 0.79 179
"""
案例:线性回归与决策树对比
""" 结论: 回归类的问题。即能使用线性回归,也能使用决策树回归 优先使用线性回归,因为决策树回归可能比较容易导致过拟合 """
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeRegressor #回归决策树
from sklearn.linear_model import LinearRegression #线性回归
import matplotlib.pyplot as plt
#1、获取数据
x = np.array(list(range(1, 11))).reshape(-1, 1)
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05])
#2、创建线性回归 和 决策树回归
es1=LinearRegression()
es2=DecisionTreeRegressor(max_depth=1)
es3=DecisionTreeRegressor(max_depth=10)
#3、模型训练
es1.fit(x,y)
es2.fit(x,y)
es3.fit(x,y)
#4、准备测试数据 ,用于测试
# 起始, 结束, 步长.
x_test = np.arange(0.0, 10.0, 0.1).reshape(-1, 1)
print(x_test)
#5、模型预测
y_predict1=es1.predict(x_test)
y_predict2=es2.predict(x_test)
y_predict3=es3.predict(x_test)
#6、绘图
plt.figure(figsize=(10,5))
#散点图
plt.scatter(x,y,color='gray',label='data')
plt.plot(x_test,y_predict1,color='g',label='liner regression')
plt.plot(x_test,y_predict2,color='b',label='max_depth=1')
plt.plot(x_test,y_predict3,color='r',label='max_depth=10')
plt.legend()
plt.xlabel("data")
plt.ylabel("target")
plt.show()
更新日期:2025年9月4日