机器学习笔记--决策树

1、决策树原理

(1)分类决策树模型是表示基于特征对实例进行分类的树形结构。决策树可以转换成一个if-then规则的集合,也可以看作是定义在特征空间划分上的类的条件概率分布。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-FjWu9boa-1691893391155)(https://note.youdao.com/yws/res/3101/73C2265C3D5F4D1AA8E82F513D527A62)]

(2)决策树学习旨在构建一个与训练数据拟合很好,并且复杂度小的决策树。因为从可能的决策树中直接选取最优决策树是NP完全问题。现实中采用启发式方法学习次优的决策树。

决策树学习算法包括3部分:特征选择、树的生成和树的剪枝。常用的算法有ID3、C4.5和CART。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DHYlP4YB-1691893391156)(https://note.youdao.com/yws/res/3105/4A00D4366CC645688EF0AE30C0ECF823)]

(3)特征选择的目的在于选取对训练数据能够分类的特征。特征选择的关键是其准则。常用的准则如下:

1)样本集合 D D D 对特征 A A A 的信息增益(ID3)

g ( D , A ) = H ( D ) − H ( D ∣ A ) g(D, A)=H(D)-H(D|A) g(D,A)=H(D)H(DA)
H ( D ) = − ∑ k = 1 K ∣ C k ∣ ∣ D ∣ log ⁡ 2 ∣ C k ∣ ∣ D ∣ H(D)=-\sum_{k=1}^{K} \frac{\left|C_{k}\right|}{|D|} \log _{2} \frac{\left|C_{k}\right|}{|D|} H(D)=k=1KDCklog2DCk
H ( D ∣ A ) = ∑ i = 1 n ∣ D i ∣ ∣ D ∣ H ( D i ) H(D | A)=\sum_{i=1}^{n} \frac{\left|D_{i}\right|}{|D|} H\left(D_{i}\right) H(DA)=i=1nDDiH(Di)

其中, H ( D ) H(D) H(D) 是数据集 D D D 的熵, H ( D i ) H(D_i) H(Di) 是数据集 D i D_i Di 的熵, H ( D ∣ A ) H(D|A) H(DA) 是数据集 D D D 对特征 A A A 的条件熵。 D i D_i Di D D D 中特征 A A A 取第 i i i 个值的样本子集, C k C_k Ck D D D 中属于第 k k k 类的样本子集。 n n n 是特征 A A A 取值的个数, K K K 是类的个数。

2)样本集合 D D D 对特征 A A A 的信息增益比(C4.5)

g R ( D , A ) = g ( D , A ) H ( D ) g_{R}(D, A)=\frac{g(D, A)}{H(D)} gR(D,A)=H(D)g(D,A)

其中, g ( D , A ) g(D,A) g(D,A) 是信息增益, H ( D ) H(D) H(D) 是数据集 D D D 的熵。

3)样本集合 D D D 的基尼指数(CART)

Gini ⁡ ( D ) = 1 − ∑ k = 1 K ( ∣ C k ∣ ∣ D ∣ ) 2 \operatorname{Gini}(D)=1-\sum_{k=1}^{K}\left(\frac{\left|C_{k}\right|}{|D|}\right)^{2} Gini(D)=1k=1K(DCk)2

特征 A A A 条件下集合 D D D 的基尼指数:

Gini ⁡ ( D , A ) = ∣ D 1 ∣ ∣ D ∣ Gini ⁡ ( D 1 ) + ∣ D 2 ∣ ∣ D ∣ Gini ⁡ ( D 2 ) \operatorname{Gini}(D, A)=\frac{\left|D_{1}\right|}{|D|} \operatorname{Gini}\left(D_{1}\right)+\frac{\left|D_{2}\right|}{|D|} \operatorname{Gini}\left(D_{2}\right) Gini(D,A)=DD1Gini(D1)+DD2Gini(D2)

(4)决策树的生成。通常使用信息增益最大、信息增益比最大或基尼指数最小作为特征选择的准则。决策树的生成往往通过计算信息增益或其他指标,从根结点开始,递归地产生决策树。这相当于用信息增益或其他准则不断地选取局部最优的特征,或将训练集分割为能够基本正确分类的子集。

(5)决策树的剪枝。由于生成的决策树存在过拟合问题,需要对它进行剪枝,以简化学到的决策树。决策树的剪枝,往往从已生成的树上剪掉一些叶结点或叶结点以上的子树,并将其父结点或根结点作为新的叶结点,从而简化生成的决策树。

2、实例:决策树分类和回归

# @Time : 2021/12/10 11:15
# @Author : xiao cong
# @Function : 鸢尾花数据集实现决策树分类

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from collections import Counter
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


def create_data():
    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=['sepal length', 'sepal width', 'petal length', 'petal width'])
    df["label"] = iris.target
    data = np.array(df.iloc[:100, [0, 1, -1]])
    return data[:, :2], data[:, -1]


X, y = create_data()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

# *****************************************************************************************************
"""
决策树分类
"""

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
from sklearn import tree

clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))
tree.plot_tree(clf)                      # 导出树
plt.show()


# **********************************************************************************************
"""
决策树回归
"""
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(1)             # 随机数种子
X = np.sort(5 * rng.rand(80, 1), axis=0)           # 维度(80,1)
y = np.sin(X).flatten()
y[::5] += 3 * (0.5 - rng.rand(16))                # 每间隔5 ,加上随机噪声

reg1 = DecisionTreeRegressor(max_depth=2)
reg2 = DecisionTreeRegressor(max_depth=5)           # 指 树的最大深度
reg1.fit(X, y)
reg2.fit(X, y)

# 预测
X_test = np.arange(0, 5, 0.01).reshape(-1, 1)              # 0~5之间每隔0.01生成一个数据。增加一个维度
y1 = reg1.predict(X_test)
y2 = reg2.predict(X_test)

plt.figure()
plt.scatter(X, y, s=20, edgecolors='black', c="darkorange", label="data")
plt.plot(X_test, y1, color="cornflowerblue", label="max_depth=2", linewidth=2)
plt.plot(X_test, y2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UTWqk23c-1691893391157)(https://note.youdao.com/yws/res/3114/EB49D203BD08458EBC70F979D387226B)]

在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值