5. 决策树算法 API
学习目标:
- 知道决策树算法 API 的具体使用
sklearn.tree.DecisionTreeClassifier(criterion='gini', splitter='best',
max_depth=None, min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.0,
max_features=None, random_state=None,
max_leaf_nodes=None,
min_impurity_decrease=0.0,
class_weight=None, ccp_alpha=0.0)
-
作用:
sklearn.tree.DecisionTreeClassifier
是一个决策树分类器,它可以用来进行分类任务。决策树是一种非参数监督学习方法,它通过从数据特征中学习简单的决策规则来预测目标变量的值。决策树可以被看作是一个分段常数近似。 -
参数:
criterion
: 用于测量分裂质量的函数。支持的标准有:- “gini”(基尼不纯度)是默认参数,即 CART 算法
- “entropy”(香农信息增益)
splitter
: 用于选择每个节点分裂的策略。支持的策略有“best”(选择最佳分裂)和“random”(选择最佳随机分裂)。max_depth
: 决策树的最大深度。如果为None
,则节点会扩展直到所有叶子都是纯净的或者直到所有叶子都包含少于min_samples_split
个样本(如果不输入的话,决策树在建立子树的时候不会限制子树的深度)。- 一般来说,数据少或者特征少的时候可以不管这个值。如果模型样本量多,特征也多的情况下,推荐限制这个最大深度,具体的取值取决于数据的分布。常用的可以取值 10-100 之间。
min_samples_split
: 分裂内部节点所需的最小样本数(内部节点再划分所需最小样本数)。- 这个值限制了子树继续划分的条件,如果某节点的样本数少于
min_samples_split
,则不会继续再尝试选择最优特征来进行划分。默认是 2。 - 如果样本量不大,不需要管这个值。如果样本量数量级非常大,则推荐增大这个值。10 万样本项目在建立决策树时,选择了
min_samples_split=10
,仅供参考。
- 这个值限制了子树继续划分的条件,如果某节点的样本数少于
min_samples_leaf
: 叶节点所需的最小样本数。- 这个值限制了叶子节点最少的样本数,如果某叶子节点数目小于样本数,则会和兄弟节点一起被剪枝。默认是 1,可以输入最少的样本数的整数,或者最少样本数占样本总数的百分比。如果样本量不大,不需要管这个值。
- 如果样本量数量级非常大,则推荐增大这个值。10 万样本使用 min_samples_leaf 的值为 5,仅供参考。
min_weight_fraction_leaf
: 叶节点所需的最小加权样本数占总权重和的比例。max_features
: 寻找最佳分裂时要考虑的特征数量。random_state
: 随机数生成器种子。max_leaf_nodes
: 最大叶节点数。min_impurity_decrease
: 如果节点分裂会导致杂质的减少大于或等于这个值,则该节点将被分裂。class_weight
: 类别权重。
-
返回值:返回一个决策树分类器对象,可以使用它来拟合数据(
fit
)、预测数据(predict
)以及进行其他操作。
6. 案例:泰坦尼克号乘客生存预测
学习目标:
- 通过案例进一步掌握决策树算法 API 的具体使用
6.1 案例背景
泰坦尼克号沉没是历史上最臭名昭着的沉船之一。1912 年 4 月 15 日,在她的处女航中,泰坦尼克号在与冰山相撞后沉没,在 2224 名乘客和机组人员中造成 1502 人死亡。这场耸人听闻的悲剧震惊了国际社会,并为船舶制定了更好的安全规定。造成海难失事的原因之一是乘客和机组人员没有足够的救生艇。尽管幸存下沉有一些运气因素,但有些人比其他人更容易生存,例如妇女,儿童和上流社会。
背景中提到的“在 2224 名乘客和机组人员中造成 1502 人死亡”这一数据并不准确。根据维基百科,泰坦尼克号上共有 2224 人,其中包括乘客和机组人员,而死亡人数在 1490-1635 人之间。
在这个案例中,我们要求完成对哪些人可能存活的分析。要求运用机器学习工具来预测哪些乘客幸免于悲剧。
我们提取到的数据集中的特征包括票的类别,是否存活,乘坐班次,年龄,家庭住址/目的地,房间,船和性别等。
数据(目前无法访问):http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.txt
数据(可以访问,但有略微出入):https://github.com/YBIFoundation/Dataset/blob/main/Titanic.txt
属性说明:
- pclass:客舱等级(
1
,2
,3
) - survived:是否幸存(
0
,1
) - name:姓名
- sex:性别
- age:年龄
- sibsp:船上兄弟姐妹/配偶的数量
- parch:船上父母/子女的数量
- ticket:船票号码
- fare:船票价格
- cabin:客舱号码
- embarked:登船港口
- boat:救生艇编号
- body:遗体识别号码
- home.dest:家庭住址/目的地
经过观察数据得到:
- pclass:客舱等级(
1
,2
,3
)是社会经济阶层的代表 - 其中 age 数据存在缺失
6.2 步骤分析
- 获取数据
- 数据基本处理
- 确定特征值和目标值
- 缺失值处理
- 数据集划分
- 特征工程(字典特征抽取)
- 机器学习(决策树)
- 模型评估
6.3 代码实现
零、导入模块
import pandas as pd
import numpy as np
from sklearn.feature_extraction import DictVectorizer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz
1. 获取数据
# 1. 获取数据
titanic = pd.read_csv("../data/titanic.txt")
titanic
titanic.describe()
2. 数据基本处理
2.1 确定特征值和目标值
我们先对各个属性进行分析,以确定哪些数据作为我们的特征值,哪些数据作为我们的目标值:
特征 | 描述 | 所属 |
---|---|---|
⭐️pclass | 客舱等级(1 , 2 , 3 ) | 可能有用,设为特征值 |
⭐️survived | 是否幸存(0 , 1 ) | 很明显是目标值 |
name | 姓名 | 没关系 |
⭐️sex | 性别 | 材料中提到性别,作为特征值 |
⭐️age | 年龄 | 材料中提到年龄,作为特征值 |
sibsp | 船上兄弟姐妹/配偶的数量 | 没关系 |
parch | 船上父母/子女的数量 | 没关系 |
ticket | 船票号码 | 没关系 |
fare | 船票价格 | 没关系 |
cabin | 客舱号码 | 没关系 |
embarked | 登船港口 | 没关系 |
boat | 救生艇编号 | 没关系 |
body | 遗体识别号码 | 没关系 |
home.dest | 家庭住址/目的地 | 没关系(在沉没之前并没有旅客下船) |
## 2.1 确定特征值和目标值
x = titanic[["pclass", "sex", "age"]]
y = titanic["survived"]
2.2 缺失值处理
## 2.2 缺失值处理
# 缺失值需要处理,将特征中有类别的特征进行字典特征抽取
# 因为缺失值为 N/A,所以我们可以直接使用 .isnull() 方法来判断是否存在缺失值
x.loc[x['age'].isnull(), 'age'] = x['age'].mean()
这行代码使用
.loc
来修改 DataFrame 中的数据。.loc
是一个索引器,它允许我们通过标签来访问 DataFrame 中的数据。在这个例子中,
.loc
的第一个参数是x['age'].isnull()
,它返回一个布尔 Series,表示每一行的 ‘age’ 列是否为缺失值。第二个参数是'age'
,表示我们要修改 ‘age’ 列的值。因此,这行代码的作用是将 ‘age’ 列中的缺失值替换为 ‘age’ 列的平均值。
注意:这里不能使用 df.dropna
,因为我们有很多值是 N/A,如果丢弃,那么对应的行也就删除了,这样会导致我们删除大量数据。所以我们要使用替换而不是删除。
详情请参考 [学习笔记] [机器学习] 1. 机器学习前置知识(机器学习概述、Matplotlib、Numpy、Pandas)中关于缺失值的处理。
N/A 是 “Not Applicable” 的缩写,意思是“不适用”。它通常用于表格或问卷中,表示某个问题对于填写者来说不适用,无法回答。有时也可以表示“无”或“没有”。
2.3 数据集划分
## 2.3 数据集划分
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=22)
3. 特征工程(字典特征抽取)
特征中出现类别符号,需要进行 one-hot 编码处理(使用 DictVectorizer
类来实现),但 DictVectorizer
类要求输入是一个字典,而我们现在的数据是 DataFrame
,所以我们需要将 DataFrame
转换为 Dict
,这里使用到了df.to_dict()方法
。
x.to_dict(orient="records") # 需要将数组特征转换成字典数据。
df.to_dict(orient="records")
是一个将 DataFrame 转换为字典的方法。它使用 orient
参数来指定转换的方式。
当 orient="records"
时,每一行数据都会被转换为一个字典,其中键(key)是列名,值(value)是该行对应列的值。最终,所有行的字典都会被组合成一个列表,作为结果返回。
例如,假设我们有一个 DataFrame 如下:
A B
0 1 2
1 3 4
当我们调用 df.to_dict(orient="records")
时,会得到以下结果:
[{'A': 1, 'B': 2}, {'A': 3, 'B': 4}]
# 3. 特征工程(字典特征抽取)
## 3.1 实例化一个字典转换器类
transfer = DictVectorizer(sparse=False) # 不用输出稀疏矩阵
## 3.2 将DataFrame转换为字典数据
x_train = x_train.to_dict(orient="records")
x_test = x_test.to_dict(orient="records")
print("转换为字典后的x_train为:", x_train)
print("\r\n转换为字典后的x_test为:", x_test)
## 3.3 特征转换
x_train = transfer.fit_transform(x_train)
# 注意:在测试数据上,应该使用与训练数据相同的转换方式,因此应该使用 `transform` 方法,
# 而不是 `fit_transform` 方法。`transform` 方法只进行转换,不会改变转换器的拟合结果。
x_test = transfer.transform(x_test)
print("\r\nx_train:\r\n", x_train)
print("\r\nx_test:\r\n", x_test)
打印结果:
转换为字典后的x_train为: [{'pclass': 1, 'sex': 'female', 'age': 39.0}, {'pclass': 2, 'sex': 'female', 'age': 19.0}, ...]
转换为字典后的x_test为: [{'pclass': 3, 'sex': 'male', 'age': 11.0}, {'pclass': 3, 'sex': 'male', 'age': 29.881137667304014}, {'pclass': 3, 'sex': 'male', ...]
x_train:
[[39. 1. 1. 0. ]
[19. 2. 1. 0. ]
[27. 3. 0. 1. ]
...
[29.88113767 3. 0. 1. ]
[24. 1. 0. 1. ]
[17. 3. 0. 1. ]]
x_test:
[[11. 3. 0. 1. ]
[29.88113767 3. 0. 1. ]
[ 4. 3. 0. 1. ]
...
[27. 2. 1. 0. ]
[49. 1. 0. 1. ]
[16. 1. 1. 0. ]]
这段代码使用了
DictVectorizer
类来进行字典特征抽取。首先,它将训练数据和测试数据从 DataFrame 转换为字典数据。然后,使用fit_transform
方法将训练数据转换为数值型数据,使用transform
方法将测试数据转换为数值型数据。
打印结果显示了转换后的训练数据和测试数据。转换后的数据是一个二维数组,其中每一行代表一个样本,每一列代表一个特征。例如,在训练数据中,第一行
[39. 1. 1. 0. ]
表示第一个样本的年龄为 39 岁,pclass 为 1,性别为女性(sex_female 为 1,sex_male 为 0)。
4. 机器学习(决策树)
决策树 API 当中,如果没有指定 max_depth
,那么会根据信息熵的条件直到最终结束。这里我们可以指定树的深度来进行限制树的大小。
# 4. 机器学习
## 4.1 定义模型
estimator = DecisionTreeClassifier(criterion="entropy", max_depth=5)
# 4.2 模型训练
estimator.fit(x_train, y_train)
print("模型训练完成!")
5. 模型评估
# 5. 模型评估
score = estimator.score(x_test, y_test)
print(f"模型准确率为:{score * 100:.2f}%")
res = estimator.predict(x_test)
print("\r\n测试集预测结果为:\r\n", res)
结果:
模型准确率为:75.61%
测试集预测结果为:
[0 0 0 1 1 0 0 0 0 1 1 0 0 1 0 0 1 0 1 0 0 1 1 0 0 1 1 0 1 1 0 0 0 1 1 1 1
0 0 0 1 0 0 0 0 1 1 0 1 0 1 1 0 0 0 1 1 1 0 0 1 0 1 0 0 0 0 0 0 1 0 1 1 0
0 0 0 1 0 1 1 0 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 0 1 0 0 0 1 1 1 0 1 0 1 0 0
0 1 1 0 0 1 1 1 1 0 0 1 0 0 1 0 1 1 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 1 0 0
1 0 0 0 1 1 0 0 1 0 0 1 0 0 1 1 1 0 1 0 0 0 1 1 0 0 1 0 1 0 0 0 0 0 1 0 0
0 0 0 0 0 1 0 1 0 0 0 0 0 1 0 1 1 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 0 0 0 1 0
1 0 0 0 0 0 1 0 0 0 1 0 1 0 0 0 1 0 1 1 1 1 1 0 1 1 0 0 0 0 1 0 0 0 0 0 0
0 0 0 0 1 0 1 1 1 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1 1 0 1 0 0 0 0 1 1 0 0 1
0 0 1 1 1 0 0 0 1 0 0 1 0 1 0 1 0 1 1 1 1 0 1 1 1 0 0 0 1 1 0 1]
【参数调整】更改 criterion
(测量分裂质量的函数)
上面模型我们用的 criterion = entropy
,下面我们使用 gini
看看模型的效果。
# 4. 机器学习
## 4.1 定义模型
estimator = DecisionTreeClassifier(criterion="gini", max_depth=5)
# 4.2 模型训练
estimator.fit(x_train, y_train)
print("模型训练完成!\r\n")
# 5. 模型评估
score = estimator.score(x_test, y_test)
print(f"模型准确率为:{score * 100:.2f}%")
res = estimator.predict(x_test)
print("\r\n测试集预测结果为:\r\n", res)
结果:
模型训练完成!
模型准确率为:75.30%
测试集预测结果为:
[0 0 0 1 1 0 0 0 0 1 1 0 0 1 0 0 1 0 1 0 0 1 1 0 0 1 1 0 1 1 0 0 0 1 1 0 1
0 0 0 1 0 0 0 0 1 1 0 1 0 1 1 0 0 0 1 1 1 0 0 1 0 1 0 0 0 0 0 0 1 0 1 1 0
0 0 0 1 0 1 1 0 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 0 1 0 0 0 1 1 1 0 0 0 1 0 0
0 0 1 0 0 1 1 1 1 0 0 1 0 0 1 0 1 1 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 1 0 0
1 0 0 0 1 1 0 0 1 0 0 1 0 0 1 1 1 0 1 0 0 0 1 1 0 0 1 0 1 0 0 0 0 0 1 0 0
0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 0 1 1 0 0 1 1 0 0 0 0 0 0 1 0
1 0 0 0 0 0 1 0 0 0 1 0 1 0 0 0 1 0 1 1 0 1 1 0 1 0 0 0 0 0 1 0 0 0 0 0 0
0 0 1 0 1 0 1 0 1 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1 1 0 1 0 0 0 0 0 1 0 0 1
0 0 1 1 1 0 0 0 1 0 0 1 0 0 0 1 0 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1]
criterion | 模型准确率 |
---|---|
Entropy | 75.61% |
Gini | 75.30% |
说明在此模型超参数条件下,使用 Entropy 作为测量分裂质量的函数是更加适合的。
【参数调整】更改决策树深度 max_depth
import matplotlib.pyplot as plt
from pylab import mpl
# 设置中文字体
mpl.rcParams["font.sans-serif"] = ["SimHei"]
# 设置正常显示符号
mpl.rcParams["axes.unicode_minus"] = False
# 4. 机器学习
tree_depth = list(range(1, 101))
acc_test_lst = []
acc_train_lst = []
for depth in tree_depth:
## 4.1 定义模型
estimator = DecisionTreeClassifier(criterion="gini", max_depth=depth)
# 4.2 模型训练
estimator.fit(x_train, y_train)
# 5. 模型评估
score_test = estimator.score(x_test, y_test)
score_train = estimator.score(x_train, y_train)
acc_test_lst.append(score_test)
acc_train_lst.append(score_train)
plt.figure(dpi=300)
plt.plot(tree_depth, acc_test_lst, label="测试集")
plt.plot(tree_depth, acc_train_lst, label="训练集")
plt.title("决策树深度与测试集准确率的关系")
plt.xlabel("决策树深度")
plt.ylabel("Accuracy")
plt.legend()
plt.show()
结果:
根据图像可知,训练集准确率和测试集准确率都随着树的深度增加而增加,但是在一定深度后,准确率趋于稳定。
这可能意味着,在树的深度达到一定值后,模型已经足够复杂,能够很好地拟合训练数据。但是,如果继续增加树的深度,可能会导致过拟合问题。
因此我们可以尝试选择一个合适的树的深度,在保证测试集准确率的同时,避免过拟合问题。
6.4 决策树可视化
6.4.1 保存树的结构到 .dot
文件
sklearn.tree.export_graphviz()
作用:用于将决策树导出为 GraphViz DOT 格式的函数。这个函数生成了一个 GraphViz 表示的决策树,然后将其写入 out_file
。一旦导出,就可以使用例如以下命令生成图形渲染:
$ dot -Tps tree.dot -o tree.ps (PostScript 格式)
$ dot -Tpng tree.dot -o tree.png (PNG 格式)
参数:
decision_tree
:要导出为 GraphViz 的决策树。out_file
:输出文件的句柄或名称。如果为 None,则结果作为字符串返回。max_depth
:表示的最大深度。如果为 None,则完全生成树。feature_names
:每个特征的名称。如果为 None,则使用通用名称(“x [0]”,“x [1]”,…)。class_names
:按升序排列的每个目标类的名称。仅与分类相关,不支持多输出。如果为 True,则显示类名的符号表示。label
:是否显示杂质等信息性标签。选项包括:- ‘all’ 在每个节点处显示
- ‘root’ 仅在顶部根节点处显示
- ‘none’ 在任何节点都不显示。
filled
:当设置为 True 时,绘制节点以指示分类的主要类别,回归值的极值或多输出节点的纯度。leaves_parallel
:当设置为 True 时,在树的底部绘制所有叶节点。impurity
:当设置为 True 时,在每个节点处显示杂质。node_ids
:当设置为 True 时,在每个节点上显示 ID 号。proportion
:当设置为 True 时,将 ‘values’ 和/或 ‘samples’ 的显示更改为比例和百分比。rotate
:当设置为 True 时,将树从左到右定向而不是从上到下。rounded
:当设置为 True 时,绘制带有圆角的节点框。special_characters
:当设置为 False 时,忽略特殊字符以实现 PostScript 兼容性。precision
:浮点数在每个节点的杂质、阈值和值属性中的精度位数。- 返回值:如果
out_file
为 None,则返回输入树的 GraphViz 点格式的字符串表示形式。
举例:
feature_names_lst = transfer.get_feature_names_out()
export_graphviz(estimator, out_file="../data/decision_tree.dot",
feature_names=['age', 'pclass', 'sex=女性', 'sex=男性'],
fontname='Microsoft YaHei')
注意:
fontname='Microsoft YaHei'
是为了防止无法显示中文- 如果我们不知道特征的名称,我们可以
transfer.get_feature_names_out()
获取特征的list
此时生成如下文件:
因为我们限制了决策树的深度,只有 5 层,因此这个树很浅,相应的文件也很小。
.dot
文件当中的内容如下:
digraph Tree {
node [shape=box, fontname="Microsoft YaHei"] ;
edge [fontname="Microsoft YaHei"] ;
0 [label="sex=男性 <= 0.5\nentropy = 0.955\nsamples = 981\nvalue = [613, 368]"] ;
1 [label="pclass <= 2.5\nentropy = 0.829\nsamples = 340\nvalue = [89, 251]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="pclass <= 1.5\nentropy = 0.351\nsamples = 182\nvalue = [12, 170]"] ;
1 -> 2 ;
3 [label="age <= 8.0\nentropy = 0.242\nsamples = 100\nvalue = [4, 96]"] ;
2 -> 3 ;
4 [label="entropy = 0.0\nsamples = 1\nvalue = [1, 0]"] ;
3 -> 4 ;
5 [label="age <= 62.5\nentropy = 0.196\nsamples = 99\nvalue = [3, 96]"] ;
3 -> 5 ;
6 [label="entropy = 0.147\nsamples = 95\nvalue = [2, 93]"] ;
5 -> 6 ;
7 [label="entropy = 0.811\nsamples = 4\nvalue = [1, 3]"] ;
5 -> 7 ;
8 [label="age <= 17.5\nentropy = 0.461\nsamples = 82\nvalue = [8, 74]"] ;
2 -> 8 ;
...
33 -> 35 ;
36 [label="age <= 45.25\nentropy = 0.536\nsamples = 474\nvalue = [416, 58]"] ;
28 -> 36 ;
37 [label="age <= 44.5\nentropy = 0.56\nsamples = 443\nvalue = [385, 58]"] ;
36 -> 37 ;
38 [label="entropy = 0.555\nsamples = 442\nvalue = [385, 57]"] ;
37 -> 38 ;
39 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1]"] ;
37 -> 39 ;
40 [label="entropy = 0.0\nsamples = 31\nvalue = [31, 0]"] ;
36 -> 40 ;
}
6.4.2 可视化展示决策树
以文字的形式来展示决策树很明显是不够直观的,因此可以借助其他工具,这里有两种方式:
- 使用 Graphviz 来可视化显示
.dot
文件。 - 使用第三方网站来可视化显示
.dot
文件。
6.4.2.1 使用 Graphviz 来可视化显示
我们可以使用 Graphviz 将 .dot
文件转换为图像格式,例如 png
或 jpg
。在命令行中输入指令来完成转换。
.png
,.jpg
和.bmp
都是常见的图像文件格式,它们各有优缺点。
.png
(Portable Network Graphics)是一种无损压缩的图像格式,这意味着它可以在压缩图像文件大小的同时保留原始质量和细节。它支持透明度,并且通常用于网络图形和文本图形。.jpg
(Joint Photographic Experts Group)是一种有损压缩的图像格式,它可以将数字图像压缩到较小的文件大小。与.bmp
或.png
文件相比,具有相似质量和分辨率的.jpg
图像可以具有更小的文件大小。但是,由于它使用了有损压缩技术,因此可能会损失一些图像质量。.bmp
(Bitmap)是一种原始且未压缩的图像文件格式。它支持多种颜色深度和可选的 α \alpha α 通道(透明度),但文件大小通常较大。选择哪种格式取决于您对图像的使用需求。例如,如果您需要在网络上使用图像,并且希望文件大小较小且加载时间短,则可以考虑使用
.jpg
格式。如果您需要保留图像的原始质量和细节,则可以考虑使用.png
格式。如果您需要处理原始且未压缩的图像数据,则可以考虑使用.bmp
格式。
举例:
# 生成.png图片
dot -Tpng your_file.dot -o output.png
# 生成.jpg图片
dot -Tjpg your_file.dot -o output.jpg
其中:
-Tpng
和-Tjpg
是 Graphviz 命令行工具中的选项,用于指定输出文件的格式。-Tpng
表示输出文件的格式为.png
-Tjpg
表示输出文件的格式为.jpg
- 可以使用不同的选项来生成不同格式的图像文件
- 例如
-Tgif
用于生成.gif
文件 -Tpdf
用于生成.pdf
文件等
- 例如
your_file.dot
是.dot
文件的名称output.png
是要生成的图片文件的名称
注意:
-T
选项指定的输出文件格式应与输出文件的扩展名保持一致,这样可以确保生成的图像文件格式正确。- 如果
-T
选项指定的输出文件格式与输出文件的扩展名不一致,可能会导致生成的图像文件无法正常打开或显示。因此,建议您确保-T
选项指定的输出文件格式与输出文件的扩展名保持一致。
生成的图片如下所示。
因为我们限制了决策树的深度,只有 5 层,因此这个树很浅
6.4.2.2 使用 第三方网站 来可视化显示
网站地址:WebGraphviz is Graphviz in the Browser
我们可以将刚才决策树的 .dot
文件的内容复制到该网站中以可视化显示,结果如下:
因为我们限制了决策树的深度,只有 5 层,因此这个树很浅
小结:
- 案例流程分析【了解】
- 获取数据
- 数据基本处理
- 确定特征值,目标值
- 缺失值处理
- 数据集划分
- 特征工程(字典特征抽取)
- 机器学习(决策树)
- 模型评估
- 决策树可视化【了解】
- 决策树导出:
sklearn.tree.export_graphviz()
- 决策树可视化
- 命令行:
dot -T图片格式 文件名.dot -o 输出图片名.图片格式
- 第三方网站:WebGraphviz is Graphviz in the Browser
- 命令行:
- 决策树导出:
7. 回归决策树
学习目标:
- 知道回归决策树的实现原理
前面已经讲到,关于数据类型,我们主要可以把其分为两类,①连续型数据和②离散型数据。
在面对不同数据时,决策树也可以分为两大类型:
- 分类决策树:主要用于处理离散型数据
- 回归决策树:主要用于处理连续型数据
连续性数据主要用于回归;离散型数据主要用于分类
7.1 原理概述
不管是回归决策树还是分类决策树,都会存在两个核心问题:
- 如何选择划分结点?
- 如何决定叶节点的输出值?
一个回归树对应着输入空间(即特征空间)的一个划分结点以及在划分单元上的输出值。在分类树中,我们采用信息论中的方法,通过计算选择最佳划分点。而在回归树中,采用的是启发式的方法。
假如我们有 n n n 个特征,每个特征有 s i ( i ∈ ( 1 , n ) ) s_i(i \in (1, n)) si(i∈(1,n)) 个取值,那我们遍历所有特征,尝试该特征所有取值,对空间进行划分,直到取到特征 j j j 的取值 s s s,使得损失函数最小,这样就得到了一个划分点。描述该过程的公式如下:
min j s [ min c 1 L ( y i , c i ) + min c 2 L ( y i , c 2 ) ] \underset{js}{\min}[\underset{c_1}{\min}\ \mathcal{L}(y_i, c_i) + \underset{c_2} {\min} \ \mathcal{L}(y_i, c_2)] jsmin[c1min L(yi,ci)+c2min L(yi,c2)]
其中:
- n n n 表示特征的数量
- s i s_i si 表示第 i i i 个特征的取值数量
- j j j 和 s s s 分别表示最佳划分点的特征和取值
- c 1 c_1 c1 和 c 2 c_2 c2 分别表示划分后两个区域内固定的输出值。
- L \mathcal{L} L 表示损失函数
假设将输入空间划分为 M M M 个单元: R 1 , R 2 , . . . , R m R_1, R_2, ..., R_m R1,R2,...,Rm,那么每个区域的输出值就是 c m = a v g ( y i ∣ x i ∈ R m ) c_m = \mathrm{avg}(y_i|x_i \in R_m) cm=avg(yi∣xi∈Rm),也就是该区域内所有点 y y y 值的平均数。
其中:
- M M M 表示输入空间被划分成的单元数
- R 1 , R 2 , . . . , R m R_1, R_2, ..., R_m R1,R2,...,Rm 表示每个单元
- c m c_m cm 表示每个区域的输出值,它等于该区域内所有点 y y y 值的平均数
- a v g ( y i ∣ x i ∈ R m ) \mathrm{avg}(y_i|x_i \in R_m) avg(yi∣xi∈Rm) 表示在 x i x_i xi 属于区域 R m R_m Rm 的条件下,所有 y i y_i yi 值的平均数。
Q1:“单元”是什么?
A1:在决策树中,单元(也称为区域)是指输入空间被划分成的子区域。决策树通过不断地选择最佳划分点来将输入空间划分成若干个单元,每个单元内的数据点具有相似的特征。每个单元都有一个固定的输出值,用来预测该区域内数据点的目标值。
Q2:划分节点就是结点吗?划分节点可以是叶子结点吗?
A2:划分节点是指决策树中的非叶子节点,它用来将输入空间划分成若干个子区域。每个划分节点都有一个划分条件,用来决定数据点属于哪个子区域。划分节点不是叶子节点,叶子节点是指决策树中没有子节点的节点,它表示一个单元,用来预测该区域内数据点的目标值。
Q3:单元 = 叶子结点,划分点 = 非叶子结点,对吗?
A3:是的。在决策树中,每个单元都对应着一个叶子节点,每个叶子节点都表示一个单元。划分节点是非叶子节点,它用来将输入空间划分成若干个子区域。
举例:如下图,假如我们想要对楼内居民的年龄进行回归,将楼划分为 3 个区域 R 1 , R 2 , R 3 R_1, R_2, R_3 R1,R2,R3(红线)。那么 R 1 R_1 R1 的输出就是第一列四个居民年龄的平均值, R 2 R_2 R2 的输出就是第二列四个居民年龄的平均值, R 3 R_3 R3 的输出就是第三、四列八个居民年龄的平均值。
7.2 算法描述
输入:训练数据集
D
D
D
输出:回归树
f
(
x
)
f(x)
f(x)
在训练数据集所在的输入空间中,递归的将每个区域划分为两个子区域并决定每个子区域上的输出值,构建二叉决策树:
一、选择最优切分特征 j j j 与切分点 s s s,求解
min j , s [ min c 1 ∑ x i ∈ R 1 ( j , s ) ( y i − c 1 ) 2 + min c 2 ∑ x i ∈ R 2 ( j , s ) ( y i − c 2 ) 2 ] \underset{j, s}{\min}\left[ \underset{c_1}{\min} \sum_{x_i \in R_1(j, s)} (y_i - c_1)^2 + \underset{c_2}{\min} \sum_{x_i \in R_2(j, s)}(y_i - c_2)^2 \right] j,smin c1minxi∈R1(j,s)∑(yi−c1)2+c2minxi∈R2(j,s)∑(yi−c2)2
遍历特征 j j j,对固定的切分特征 j j j 扫描切分点 s s s,选择使得上式达到最小值的对 ( j , s ) (j, s) (j,s)
二、用选定的对 ( j , s ) (j, s) (j,s) 划分区域并决定相应的输出值:
R 1 ( j , s ) = x ∣ x ( j ) ≤ s R_1(j, s) = x|x^{(j)} \le s R1(j,s)=x∣x(j)≤s
R 2 ( j , s ) = x ∣ x ( j ) > s R_2(j, s) = x|x^{(j)} >s R2(j,s)=x∣x(j)>s
c ^ m = 1 N ∑ x 1 ∈ R m ( j , s ) y i 其中 x ∈ R m , m = 1 , 2 \hat{c}_m = \frac{1}{N}\sum_{x_1 \in R_m(j, s)} y_i \ \ 其中x \in R_m, m = 1, 2 c^m=N1x1∈Rm(j,s)∑yi 其中x∈Rm,m=1,2
三、继续对两个子区域调用步骤一和二,直至满足停止条件。
四、将输入空间划分为 M M M 个区域 R 1 , R 2 , . . . , R M R_1, R_2, ..., R_M R1,R2,...,RM,生成决策树:
f ( x ) = ∑ m = 1 M c ^ m I ( x ∈ R m ) f(x) = \sum_{m = 1}^M \hat{c}_mI(x\in R_m) f(x)=m=1∑Mc^mI(x∈Rm)
其中:
- D D D 表示训练数据集
- f ( x ) f(x) f(x) 表示回归树
- j j j 和 s s s 分别表示最优切分特征和切分点
- R 1 ( j , s ) R_1(j, s) R1(j,s) 和 R 2 ( j , s ) R_2(j, s) R2(j,s) 分别表示根据最优切分特征和切分点划分出的两个子区域
- c 1 c_1 c1 和 c 2 c_2 c2 分别表示两个子区域内的输出值
- c ^ m \hat{c}_m c^m 表示第 m m m 个区域内的输出值,它等于该区域内所有点 y y y 值的平均数
- M M M 表示输入空间被划分成的区域数
- R 1 , R 2 , . . . , R M R_1, R_2, ..., R_M R1,R2,...,RM 表示每个区域。
7.3 简单实例
为了易于理解,接下来通过一个简单实例加深对回归决策树的理解。训练数据见下表,我们的目标是得到一棵最小二乘回归树。
x x x(特征值) | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
y y y(目标值) | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 | 8.9 | 8.7 | 9 | 9.05 |
7.3.1 实例计算过程
一、选择最优的切分特征 j j j 与最优切分点 s s s:
- 确定第一个问题:选择最优切分特征
- 在本数据集中,只有一个特征,因此最优切分特征自然是 x x x
- 确定第二个问题:我们考虑 9 个切分点
[
1.5
,
2.5
,
3.5
,
4.5
,
5.5
,
6.5
,
7.5
,
8.5
,
9.5
]
[1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]
[1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5]。
- 损失函数定义为平方损失函数 L ( y , f ( x ) ) = [ f ( x ) − y ] 2 \mathcal{L}(y, f(x)) = [f(x) - y]^2 L(y,f(x))=[f(x)−y]2,其中 f ( x ) f(x) f(x) 为预测值, y y y 为真实值(目标值)
- 将上述 9 个切分点依此代入下面的公式,其中 c m = a v g ( y i ∣ x i ∈ R m ) c_m = \mathrm{avg}(y_i | x_i \in R_m) cm=avg(yi∣xi∈Rm)
a. 计算子区域输出值:
当切分点 s = 1.5 s=1.5 s=1.5 时,数据被分为两个子区域: R 1 R_1 R1 和 R 2 R_2 R2。 R 1 R_1 R1 包括特征值为 1 1 1 的数据点,而 R 2 R_2 R2 包括特征值为 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 2,3,4,5,6,7,8,9,10 2,3,4,5,6,7,8,9,10 的数据点。
c 1 c_1 c1 和 c 2 c_2 c2 分别是这两个子区域的输出值。它们的计算方法是将各自子区域内的目标值相加,然后除以数据点的数量。因此,这两个区域的输出值分别为:
- c 1 = 5.56 c_1 = 5.56 c1=5.56
- c 2 = 5.7 + 5.91 + 6.4 + 6.8 + 7.05 + 8.9 + 8.7 + 9 + 9.05 9 = 7.50 c_2= \frac{5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05}{9} = 7.50 c2=95.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05=7.50
当切分点 s = 2.5 s=2.5 s=2.5 时,数据被分为两个子区域: R 1 R_1 R1 和 R 2 R_2 R2。 R 1 R_1 R1 包括特征值为 1 , 2 1,2 1,2 的数据点,而 R 2 R_2 R2 包括特征值为 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 3,4,5,6,7,8,9,10 3,4,5,6,7,8,9,10 的数据点。
c 1 c_1 c1 和 c 2 c_2 c2 分别是这两个子区域的输出值。它们的计算方法是将各自子区域内的目标值相加,然后除以数据点的数量。因此,这两个区域的输出值分别为:
-
c 1 = 5.56 + 5.7 2 = 5.63 c_1 = \frac{5.56 + 5.7}{2} = 5.63 c1=25.56+5.7=5.63
-
c 2 = 5.91 + 6.4 + 6.8 + 7.05 + 8.9 + 8.7 + 9 + 9.05 8 = 7.73 c_2 = \frac{5.91+6.4+6.8+7.05+8.9+8.7+9+9.05}{8} = 7.73 c2=85.91+6.4+6.8+7.05+8.9+8.7+9+9.05=7.73
同理,我们可以得到其他各切分点的子区域输出值,如下表所示:
s s s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | 8.5 | 9.5 |
---|---|---|---|---|---|---|---|---|---|
c 1 c1 c1 | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 | 6.24 | 6.62 | 6.88 | 7.11 |
c 2 c2 c2 | 7.5 | 7.73 | 7.99 | 8.25 | 8.54 | 8.91 | 8.92 | 9.03 | 9.05 |
b. 计算损失函数值,找到最优切分点:
把 c 1 c_1 c1, c 2 c_2 c2 的值代入到平方损失函数 L ( y , f ( x ) ) = [ f ( x ) − y ] 2 \mathcal{L}(y, f(x)) = [f(x) - y]^2 L(y,f(x))=[f(x)−y]2,其中 f ( x ) f(x) f(x) 为预测值, y y y 为真实值(目标值)
当s=1.5时:总损失为:
L = ∑ x i ∈ R 1 [ f ( x i ) − y i ] 2 + ∑ x i ∈ R 2 [ f ( x i ) − y i ] 2 = [ 5.56 − 5.56 ] 2 + [ 7.50 − 5.7 ] 2 + [ 7.50 − 5.91 ] 2 + . . . + [ 7.50 − 9.05 ] 2 = 0 + ( 1.8 ) 2 + ( 1.59 ) 2 + . . . + ( − 1.55 ) 2 = 15.72 \begin{aligned} \mathcal{L} &= \sum_{x_i \in R_1} [f(x_i) - y_i]^2 + \sum_{x_i \in R_2} [f(x_i) - y_i]^2 \\ &= [5.56 - 5.56]^2 + [7.50 - 5.7]^2 + [7.50 - 5.91]^2 + ... + [7.50 - 9.05]^2 \\ &= 0 + (1.8)^2 + (1.59)^2 + ... + (-1.55)^2 & = 15.72 \end{aligned} L=xi∈R1∑[f(xi)−yi]2+xi∈R2∑[f(xi)−yi]2=[5.56−5.56]2+[7.50−5.7]2+[7.50−5.91]2+...+[7.50−9.05]2=0+(1.8)2+(1.59)2+...+(−1.55)2=15.72
当切分点 s = 2.5 s=2.5 s=2.5 时,总损失为:
L = ∑ x i ∈ R 1 [ f ( x i ) − y i ] 2 + ∑ x i ∈ R 2 [ f ( x i ) − y i ] 2 = [ 5.63 − 5.56 ] 2 + [ 5.63 − 5.7 ] 2 + [ 7.73 − 5.91 ] 2 + . . . + [ 7.73 − 9.05 ] 2 = ( 0.07 ) 2 + ( − 0.07 ) 2 + ( 1.82 ) 2 + . . . + ( − 1.32 ) 2 \begin{aligned} \mathcal{L} &= \sum_{x_i \in R_1} [f(x_i) - y_i]^2 + \sum_{x_i \in R_2} [f(x_i) - y_i]^2 \\ &= [5.63 - 5.56]^2 + [5.63 - 5.7]^2 + [7.73 - 5.91]^2 + ... + [7.73 - 9.05]^2 \\ &= (0.07)^2 + (-0.07)^2 + (1.82)^2 + ... + (-1.32)^2 \end{aligned} L=xi∈R1∑[f(xi)−yi]2+xi∈R2∑[f(xi)−yi]2=[5.63−5.56]2+[5.63−5.7]2+[7.73−5.91]2+...+[7.73−9.05]2=(0.07)2+(−0.07)2+(1.82)2+...+(−1.32)2
同理,计算得到其他各切分点的损失函数值,可获得下表:
s s s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | 8.5 | 9.5 |
---|---|---|---|---|---|---|---|---|---|
m ( s ) m(s) m(s) | 15.72 | 12.07 | 8.36 | 5.78 | 3.91 | 1.93 | 8.01 | 11.73 | 15.74 |
显然取 s = 6.5 s=6.5 s=6.5 时, m ( s ) m(s) m(s) 最小。因此第一个划分变量 [ j = x , s = 6.5 ] [j=x,s=6.5] [j=x,s=6.5]
Q:为什么要用
m
(
s
)
m(s)
m(s),不应该是
L
(
y
,
f
(
x
)
)
\mathcal{L}(y, f(x))
L(y,f(x))吗?
A:
m
(
s
)
m(s)
m(s) 和
L
\mathcal{L}
L 都表示损失函数。在回归决策树中,损失函数用于衡量划分后的子区域内预测值与真实值之间的差异。不同的文献或资料可能会使用不同的符号来表示损失函数,但它们的意义是相同的。
m ( s ) m(s) m(s) 用于表示在切分点 s s s 处的损失函数值。因此,当计算不同切分点处的损失函数值时,使用 m ( s ) m(s) m(s) 或 L \mathcal{L} L 都是可以的。
二、用选定的 ( j , s ) (j, s) (j,s) 划分区域,并决定输出值:
- 两个区域分别是: R 1 = 1 , 2 , 3 , 4 , 5 , 6 R_1={1,2,3,4,5,6} R1=1,2,3,4,5,6, R 2 = 7 , 8 , 9 , 10 R_2={7,8,9,10} R2=7,8,9,10
- 输出值 c m = a v g ( y i ∣ x i ∈ R m ) c_m = \mathrm{avg}(y_i|x_i\in R_m) cm=avg(yi∣xi∈Rm), c 1 = 6.24 c_1 =6.24 c1=6.24, c 2 = 8.91 c_2 = 8.91 c2=8.91
三、调用步骤一、二,继续划分:
对 R 1 R_1 R1 继续进行划分:
x x x(特征值) | 1 | 2 | 3 | 4 | 5 | 6 |
---|---|---|---|---|---|---|
y y y(目标值) | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 |
取切分点 [ 1.5 , 2.5 , 3.5 , 4.5 , 5.5 ] [1.5,2.5,3.5,4.5,5.5] [1.5,2.5,3.5,4.5,5.5],则各区域的输出值 c c c 如下表:
s s s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
---|---|---|---|---|---|
c 1 c1 c1 | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 |
c 2 c2 c2 | 6.37 | 6.54 | 6.75 | 6.93 | 7.02 |
计算损失函数值 m ( s ) m(s) m(s):
s s s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
---|---|---|---|---|---|
m ( s ) m(s) m(s) | 1.3087 | 0.754 | 0.2771 | 0.4368 | 1.0644 |
s = 3.5 s=3.5 s=3.5 时, m ( s ) m(s) m(s) 最小。
循环…
回归决策树的划分终止条件通常有以下几种:
- 子区域中的数据点数量小于预先设定的阈值。
- 子区域中的数据点目标值的方差小于预先设定的阈值。
- 树的深度达到预先设定的最大深度。
当满足上述任意一个条件时,划分过程将终止。这些条件可以根据具体问题进行调整,以获得最佳的模型性能。
四、生成回归树
假设在生成 3 个区域之后停止划分,那么最终生成的回归树形式如下:
T = { 5.72 x ≤ 3.5 6.75 3.5 ≤ x ≤ 6.5 8.91 6.5 < x T = \begin{cases} 5.72 & x \le 3.5 \\ 6.75 & 3.5 \le x \le 6.5 \\ 8.91 & 6.5 < x \end{cases} T=⎩ ⎨ ⎧5.726.758.91x≤3.53.5≤x≤6.56.5<x
这棵回归树的结构如下:
[j=x,s=6.5]
/ \
[j=x,s=3.5] R_2
/ \
R_{11} R_{12}
其中, R 11 R_{11} R11、 R 12 R_{12} R12 和 R 2 R_2 R2 都是叶子节点。
这棵回归树有三个叶子节点,分别对应三个子区域 R 11 R_{11} R11、 R 12 R_{12} R12 和 R 2 R_2 R2。根节点的划分变量为 [ j = x , s = 6.5 ] [j=x,s=6.5] [j=x,s=6.5],它将数据分为两个子区域: R 1 R_1 R1 和 R 2 R_2 R2。根节点的左子节点对应子区域 R 1 R_1 R1,它的划分变量为 [ j = x , s = 3.5 ] [j=x,s=3.5] [j=x,s=3.5],将子区域 R 1 R_1 R1 再次分为两个子区域: R 11 R_{11} R11 和 R 12 R_{12} R12。根节点的左子节点的左右子节点分别对应子区域 R 11 R_{11} R11 和 R 12 R_{12} R12,它们都是叶子节点。根节点的右子节点对应子区域 R 2 R_2 R2,它也是一个叶子节点。
其中:
- j j j 和 s s s 分别表示切分特征和切分点
-
j
=
x
j=x
j=x 表示切分特征为
x
x
x,而
s
=
6.5
s=6.5
s=6.5 表示切分点为
6.5
6.5
6.5
- 当切分变量为 [ j = x , s = 6.5 ] [j=x,s=6.5] [j=x,s=6.5] 时,数据将根据特征 x x x 的值被分为两个子区域: R 1 R_1 R1 和 R 2 R_2 R2。子区域 R 1 R_1 R1 包括特征值小于等于 6.5 6.5 6.5 的数据点,而子区域 R 2 R_2 R2 包括特征值大于 6.5 6.5 6.5 的数据点。
- 因此,当切分变量为 [ j = x , s = 6.5 ] [j=x,s=6.5] [j=x,s=6.5] 时,数据将根据特征 x x x 的值被分为两个子区域。
小结:
- 输入:训练数据集 D D D
- 输出:回归树 f ( x ) f(x) f(x)
- 流程:在训练数据集所在的输入空间中,递归的将每个区域划分为两个子区域并决定每个子区域上的输出值,构建二叉决策树:
- 选择最优切分特征 j j j 与切分点 s s s,求解 min j , s [ min c 1 ∑ x i ∈ R 1 ( j , s ) ( y i − c 1 ) 2 + min c 2 ∑ x i ∈ R 2 ( j , s ) ( y i − c 2 ) 2 ] \underset{j, s}{\min}\left[ \underset{c_1}{\min} \sum_{x_i \in R_1(j, s)} (y_i - c_1)^2 + \underset{c_2}{\min} \sum_{x_i \in R_2(j, s)}(y_i - c_2)^2 \right] j,smin[c1min∑xi∈R1(j,s)(yi−c1)2+c2min∑xi∈R2(j,s)(yi−c2)2] —— 遍历特征 j j j,对固定的切分特征 j j j 扫描切分点 s s s,选择使得上式达到最小值的对 ( j , s ) (j, s) (j,s)
- 用选定的对
(
j
,
s
)
(j, s)
(j,s) 划分区域并决定相应的输出值:
R 1 ( j , s ) = x ∣ x ( j ) ≤ s R_1(j, s) = x|x^{(j)} \le s R1(j,s)=x∣x(j)≤s
R 2 ( j , s ) = x ∣ x ( j ) > s R_2(j, s) = x|x^{(j)} >s R2(j,s)=x∣x(j)>s
c
^
m
=
1
N
∑
x
1
∈
R
m
(
j
,
s
)
y
i
其中
x
∈
R
m
,
m
=
1
,
2
\hat{c}_m = \frac{1}{N}\sum_{x_1 \in R_m(j, s)} y_i \ \ 其中x \in R_m, m = 1, 2
c^m=N1x1∈Rm(j,s)∑yi 其中x∈Rm,m=1,2
3. 继续对两个子区域调用步骤一和二,直至满足停止条件。
4. 将输入空间划分为
M
M
M 个区域
R
1
,
R
2
,
…
,
R
M
R_1, R_2 , …, R_M
R1,R2,…,RM,生成决策树
f
(
x
)
=
∑
m
=
1
M
c
^
m
I
(
x
∈
R
m
)
f(x) = \sum_{m = 1}^M \hat{c}_mI(x\in R_m)
f(x)=∑m=1Mc^mI(x∈Rm):
7.4 回归决策树和线性回归对比
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression
from pylab import mpl
# 设置中文字体
mpl.rcParams["font.sans-serif"] = ["SimHei"]
# 设置正常显示符号
mpl.rcParams["axes.unicode_minus"] = False
# 1. ⽣成数据
x = np.array(list(range(1, 11))).reshape(-1, 1) # 使其变为列向量
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05])
# 2. 训练模型
model_1 = DecisionTreeRegressor(max_depth=1) # 决策树模型
model_2 = DecisionTreeRegressor(max_depth=3) # 决策树模型
model_3 = LinearRegression() # 线性回归模型
model_1.fit(x, y)
model_2.fit(x, y)
model_3.fit(x, y)
# 3. 模型预测
X_test = np.arange(0.0, 10.0, 0.01).reshape(-1, 1) # ⽣成1000个数,⽤于预测模型
predict_1 = model_1.predict(X_test)
predict_2 = model_2.predict(X_test)
predict_3 = model_3.predict(X_test)
# 4. 结果可视化
plt.figure(dpi=300)
plt.scatter(x, y, label="原始数据(目标值)")
plt.plot(X_test, predict_1, label="回归决策树: max_depth=1")
plt.plot(X_test, predict_2, label="回归决策树: max_depth=3")
plt.plot(X_test, predict_3, label="线性回归")
plt.xlabel("数据")
plt.ylabel("预测值")
plt.title("线性回归与回归决策树效果对比")
plt.grid(alpha=0.5)
plt.legend()
plt.show()
结果:
8. 决策树总结
8.1 优点
- 易于理解和解释。
- 决策树的结构可以可视化,非专家也能很容易理解。
- 数据准备简单。
- 决策树不需要对数据进行复杂的预处理,例如归一化或去除缺失值。
- 能够同时处理数值型和分类数据。
- 不受数据缩放的影响。
- 计算成本相对较低。
这些优点使得决策树在许多领域都得到了广泛应用。
8.2 缺点
- 容易过拟合。决策树模型可能会产生过于复杂的模型,导致泛化能力较差。
- 可以通过剪枝、设置叶节点所需的最小样本数或设置树的最大深度来避免过拟合。
- 不稳定性。微小的数据变化可能会导致生成完全不同的树。
- 这个问题可以通过决策树集成来缓解。
- 对连续性字段预测困难。
- 当类别太多时,错误率可能会增加较快。
这些缺点需要在使用决策树时予以注意。
8.3 改进的方法
针对决策树的缺点,有一些改进方法可以使用。例如:
- 避免过拟合。
- 可以通过剪枝、设置叶节点所需的最小样本数或设置树的最大深度来避免过拟合。
- 剪枝包括预剪枝和后剪枝。
- 前者通过对连续型变量设置阈值,来控制树的深度,或者控制节点的个数,在节点开始划分之前就进行操作,进而防止过拟合现象。
- 后者是自底向上对非叶节点进行考察,如果这个内部节点换成叶节点能提升决策树的泛化能力,那就把它换掉。
- 使用决策树集成。
- 可以通过集成多个决策树来提高模型的稳定性和准确性。
- 例如,随机森林算法就是基于决策树的集成学习算法,它通过构建多棵决策树并结合它们的预测结果来提高模型的准确性和稳定性。
- 可以通过集成多个决策树来提高模型的稳定性和准确性。
- 对连续性字段进行离散化处理。
- 可以将连续性字段离散化为分类变量,以便决策树能够更好地处理。
- 对类别不平衡的数据进行重采样。
- 可以对类别不平衡的数据进行重采样,以减少错误率。
这些方法可以帮助改进决策树模型,提高其准确性和稳定性。