机器学习决策树

一、简述

1.1概念

        决策树是一种通过构建树状模型对数据进行分类和决策的监督学习算法,它通过特征选择和树的构建来实现预测和决策的过程,用于解决分类和回归问题。

        决策树由节点(node)和边(edge)组成。每个内部节点表示一个特征或属性,用于对数据进行划分,而叶子节点表示最终的决策结果。决策树的根节点是最顶层的特征,而每个内部节点的子节点对应于该特征的可能取值。通过沿着树从根节点到叶子节点的路径,可以根据特征的取值来进行决策或分类。

2.2优缺点

        优点:计算复杂度较低。在生成决策树时,只需对每个特征进行有限次数的比较操作,因此计算复杂度相对较低,适用于处理大规模数据集。

        缺点:容易过拟合。决策树倾向于过拟合训练数据,即在训练集上表现良好但在测试集上表现较差。过拟合问题可以通过剪枝等技术来缓解。

二、构造

2.1决策树构建的一般流程

  1. 特征选择:从训练数据的特征中选择一个特征作为当前节点的分裂标准(特征选择的标准不同产生了不同的特征决策树算法)。

  2. 决策树生成:根据所选特征评估标准,从上至下递归地生成子节点,直到数据集不可分则停止决策树停止声场。

  3. 决策树剪枝:决策树容易过拟合,需要剪枝来缩小树的结构和规模(包括预剪枝和后剪枝)。

2.2信息增益

        信息增益是一种用于特征选择的指标,它表示一个特征对于模型预测目标变量的贡献程度。信息增益越大,说明该特征在分类中的作用越大,应优先选择。

        信息增益的计算基于信息熵的概念。信息熵是度量样本集合纯度的指标,定义为:

其中,pi​ 表示分类 i 在样本集合中出现的频率。

        信息增益越大,表示使用该特征进行划分后可以更好地区分样本,因此应选择信息增益最大的特征作为节点进行划分。

2.3基尼值和基尼指数   

        基尼值(Gini value)和基尼指数(Gini index)是衡量分类问题中不纯度或混乱程度的指标,常用于决策树算法中的特征选择。

        对于一个包含多个类别的样本集合,基尼值表示从该样本集合中随机选择两个样本,其类别标签不一致的概率。基尼值越大,表示样本集合越杂乱。对于一个二分类问题,基尼值可以通过以下公式计算:

        

        其中,pi​ 表示第 i 个类别在样本集合中的比例。

        基尼指数是在决策树算法中用于特征选择的指标,它表示使用某个特征进行划分后,样本集合的基尼值的减少量。基尼指数越小,表示使用该特征进行划分后样本的不纯度减少得越多,即该特征具有更好的划分能力。对于一个特征 A,其基尼指数可以通过以下公式计算:

        

        其中,V 表示特征 A 的取值个数,∣Sv​∣ 表示特征 A 取值为 v 的样本数,Gini(Sv​) 表示特征 A 取值为 v 的子集的基尼值。

        在决策树的特征选择过程中,通常选择基尼指数最小的特征作为节点进行划分。这意味着该特征能够带来最大的信息增益或最大的不纯度减少。

2.4剪枝处理      

        剪枝是一种用于降低决策树模型复杂度、避免过拟合的技术,其目的是移除一些不必要的分支或叶子节点,以提高模型的泛化能力。

        剪枝处理通常分为预剪枝后剪枝两种方式。预剪枝是在构造决策树时,在每个节点处对应用于划分的特征进行评估,如果划分后不能显著提高模型性能,就停止该节点的进一步划分,将该节点转化为叶子节点。预剪枝可以有效地减少决策树的大小和复杂度,但它需要事先确定一个合适的停止条件,因此可能会导致欠拟合。后剪枝是在决策树构建完成后,对决策树进行修剪。具体而言,它通过移除某些子树或叶子节点来改善决策树的泛化能力。后剪枝通常采用验证集的方法来确定哪些子树应该被剪掉,具体步骤如下:

  1. 将原始数据集划分为训练集和验证集;
  2. 使用训练集构建决策树;
  3. 对每个非叶子节点,尝试将其替换为叶子节点,观察此时在验证集上的性能变化,并记录最优性能;
  4. 重复步骤3,直到对所有非叶子节点都进行了尝试;
  5. 根据记录的最优性能修剪决策树。

        后剪枝可以有效地提高决策树的泛化能力,但需要一定的计算时间和额外的数据集用于验证。

        无论是预剪枝还是后剪枝,剪枝处理都可以防止决策树过拟合,提高模型的泛化性能。

三、实现

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 构建决策树模型
clf = DecisionTreeClassifier()

# 拟合决策树模型
clf.fit(X, y)

# 绘制决策树
plt.figure(figsize=(12, 8))
plot_tree(clf, filled=True, rounded=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.show()

测试结果:

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值