《机器学习实战:基于Scikit-Learn、Keras和TensorFlow第2版》-学习笔记(6):决策树

本文详细介绍了决策树在机器学习中的应用,包括训练流程、可视化方法、预测机制、CART算法、计算复杂度、基尼不纯度和熵的比较、正则化超参数以及回归任务。决策树因其直观性和解释性而被视为白盒模型,但也存在过拟合和不稳定性问题。通过调整超参数和正则化,可以改善模型性能。
摘要由CSDN通过智能技术生成

· Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, 2nd Edition, by Aurélien Géron (O’Reilly). Copyright 2019 Aurélien Géron, 978-1-492-03264-9.
· 环境:Anaconda(Python 3.8) + Pycharm
· 学习时间:2022.05.04~2022.05.05

第六章 决策树

与SVM一样,决策树是通用的机器学习算法,可以执行分类和回归任务,甚至多输出任务。它们是功能强大的算法,能够拟合复杂的数据集。例如,在第2章中,你在加州房屋数据集中训练了DecisionTreeRegressor模型,使其完全拟合(实际上是过拟合)。

决策树也是随机森林的基本组成部分(见第7章),它们是当今最强大的机器学习算法之一。

在本章中,我们将从讨论如何使用决策树进行训练、可视化和做出预测开始。然后,我们将了解Scikit-Learn使用的CART训练算法,并将讨论如何对树进行正则化并将其用于回归任务。最后,我们将讨论决策树的一些局限性。

6.1 训练和可视化决策树

为了理解决策树,让我们建立一个决策树,然后看看它是如何做出预测的。以下代码在鸢尾花数据集上训练了一个DecisionTreeClassifier:

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(X, y)

要将决策树可视化,首先,使用export_graphviz()方法输出一个图形定义文件,命名为iris_tree.dot:

import os
from sklearn.tree import export_graphviz

export_graphviz(
    tree_clf,
    out_file=os.path.join("iris_tree.dot"),
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True
)

然后,你可以使用Graphviz软件包中的dot命令行工具将此.dot文件转换为多种格式,例如PDF或PNG。此命令行将.dot文件转换为.png图像文件:

$ dot -Tpng iris_tree.dot -o iris_tree.png

PyCharm安装一个.dot的插件就好了。

绘制如图:

在这里插入图片描述

6.2 做出预测

让我们看看上图中的树是如何进行预测的。假设你找到一朵鸢尾花,要对其进行分类。你从根节点开始(深度为0,在顶部):该节点询问花的花瓣长度是否小于2.45cm。如果是,则向下移动到根的左子节点(深度1,左)。在这种情况下,它是一片叶子节点(即它没有任何子节点),因此它不会提出任何问题:只需查看该节点的预测类,然后决策树就可以预测花朵是山鸢尾花(class=setosa)。

现在假设你发现了另一朵花,这次花瓣的长度大于2.45cm,你必须向下移动到根的右子节点(深度1,右),该子节点不是叶子节点,因此该节点会问另一个问题:花瓣宽度是否小于1.75cm?如果是,则你的花朵很可能是变色鸢尾花(深度2,左)。如果不是,则可能是维吉尼亚鸢尾花(深度2,右)。就是这么简单。

决策树的许多特质之一就是它们几乎不需要数据准备。实际上,它们根本不需要特征缩放或居中

节点的samples属性统计它应用的训练实例数量。例如,有100个训练实例的花瓣长度大于2.45cm(深度1,右),其中54个花瓣宽度小于1.75cm(深度2,左)。

节点的value属性说明了该节点上每个类别的训练实例数量。例如,右下节点应用在0个山鸢尾、1个变色鸢尾和45个维吉尼亚鸢尾实例上。

最后,节点的gini属性衡量其不纯度(impurity):如果应用的所有训练实例都属于同一个类别,那么节点就是“纯”的(gini=0)。例如,深度1左侧节点仅应用于山鸢尾花训练实例,所以它就是纯的,并且gini值为0。公式6-1说明了第i个节点的基尼系数Gi的计算方式。例如,深度2左侧节点,基尼系数等于 1 – ( 0 / 54 ) 2 – ( 49 / 54 ) 2 – ( 5 / 54 ) 2 ≈ 0.168 1–(0/54)^2–(49/54)^2–(5/54)^2≈0.168 1(0/54)2(49/54)2(5/54)20.168。(基尼不纯度: G i = 1 − ∑ k = 1 n p i , k 2 G_i = 1-\sum^n_{k=1}p_{i,k}^2 Gi=1

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

新四石路打卤面

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值