sklearn学习04——DecisionTree


前言

本篇简单回顾一下决策树原理,然后采用sklearn的鸢尾花数据集实现决策树模型的训练和预测。

一、决策树原理

这里只说明决策树的基本算法流程和最优划分属性的选择(核心思想),具体细节可以参考我的另一篇文章:机器学习第四章—决策树

1.1、算法基本流程

学习算法
显然,决策树的生成过程是一个递归过程。这里,分支完成的标志有三种情况,也是递归返回的三种情况:

  • 当前结点包含的样本全属于同一类别,无需划分; (即已经分好类的情况)
  • 当前属性集为空,或是所有样本在所有属性上取值相同,无法划分;(所有属性都被判断过,或者每个样本特征都相同导致无法划分)
  • 当前结点包含的样本集合为空,不能划分。(当前结点已经没有样本,不需要再划分)

处理方式:

  • 在第(1)种情形下,直接把当前结点标记为叶结点,类别也就是当前所有样本所属的这个类别;
  • 在第(2)种情形下,我们把当前结点标记为叶结点,并将其类别设定为该结点所含样本最多的类别;(所有属性都判断完了还没分好,这就看该结点下的这些样本哪一类多了)
  • 在第(3)种情形下,同样把当前结点标记为叶结点,但将其类别设定为其父结点所含样本最多的类别(也就是这种特征的样本,训练集中没有,所以要靠父结点来帮助,其实就是决策树能力不够,使得判断条件松了一级)。

1.2、最优划分属性的选择

决策树的生成过程中,关键一步是:最优划分属性的生成。即找到能将当前数据集划分最好的属性(划分之后的几个集合纯度越高,划分的越好)。关于选择最划分优属性的标准,一般有以下三种:

  • 信息增益。代表实现: ID3 决策树
  • 增益率。代表实现:C4.5 决策树
  • 基尼指数。代表实现: CART 决策树

这三种标准的具体数学形式不再展开,详见上述链接。下面使用sklearn所实现的一种决策树就是按照基尼指数的标准进行划分、生成决策树的。

二、sklearn代码实践

2.1、引入库

这里使用 iris (鸢尾花)数据集。数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种。

代码如下:

import seaborn as sns
from pandas import plotting
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import tree

2.2、查看数据集信息

打印输出鸢尾花数据集特征、种类(样本标记)的信息,顺带查看前几条样本。

代码如下:

# 加载数据集
data = load_iris() 
# 转换成.DataFrame形式
df = pd.DataFrame(data.data, columns = data.feature_names)
# 添加品种列
df['Species'] = data.target
# 查看数据集信息
print(f"数据集信息:\n{df.info()}")
# 查看前5条数据
print(f"前5条数据:\n{df.head()}")
# 查看各特征列的摘要信息
df.describe()

2.3、使用可视化类直观分析各个特征与品种的关系

代码如下:

# 设置颜色主题
antV = ['#1890FF', '#2FC25B', '#FACC14', '#223273', '#8543E0', '#13C2C2', '#3436c7', '#F04864'] 
# 绘制violinplot
f, axes = plt.subplots(2, 2, figsize=(8, 8), sharex=True)
sns.despine(left=True) # 删除上方和右方坐标轴上不需要的边框,这在matplotlib中是无法通过参数实现的
sns.violinplot(x='Species', y=df.columns[0], data=df, palette=antV, ax=axes[0, 0])
sns.violinplot(x='Species', y=df.columns[1], data=df, palette=antV, ax=axes[0, 1])
sns.violinplot(x='Species', y=df.columns[2], data=df, palette=antV, ax=axes[1, 0])
sns.violinplot(x='Species', y=df.columns[3], data=df, palette=antV, ax=axes[1, 1])
plt.show()
# 绘制pointplot
f, axes = plt.subplots(2, 2, figsize=(8, 6), sharex=True)
sns.despine(left=True)
sns.pointplot(x='Species', y=df.columns[0], data=df, color=antV[1], ax=axes[0, 0])
sns.pointplot(x='Species', y=df.columns[1], data=df, color=antV[1], ax=axes[0, 1])
sns.pointplot(x='Species', y=df.columns[2], data=df, color=antV[1], ax=axes[1, 0])
sns.pointplot(x='Species', y=df.columns[3], data=df, color=antV[1], ax=axes[1, 1])
plt.show()
# g = sns.pairplot(data=df, palette=antV, hue= 'Species')
# 安德鲁曲线
plt.subplots(figsize = (8,6))
plotting.andrews_curves(df, 'Species', colormap='cool')

plt.show()

得到以下的鸢尾花四个特征与品种的关系图:
关系图1
关系图2
关系图3

2.4、训练决策树(基于Gini值)

代码如下:

# 加载数据集
data = load_iris() 
# 转换成.DataFrame形式
df = pd.DataFrame(data.data, columns = data.feature_names)
# 添加品种列
df['Species'] = data.target

# 用数值替代品种名作为标签
target = np.unique(data.target)
target_names = np.unique(data.target_names)
targets = dict(zip(target, target_names))
df['Species'] = df['Species'].replace(targets)

# 提取数据和标签
X = df.drop(columns="Species")
y = df["Species"]
feature_names = X.columns
labels = y.unique()

X_train, test_x, y_train, test_lab = train_test_split(X,y,
                                                 test_size = 0.4,
                                                 random_state = 42)
model = DecisionTreeClassifier(max_depth =3, random_state = 42)
model.fit(X_train, y_train) 
# 以文字形式输出树     
text_representation = tree.export_text(model)
print(text_representation)
# 用图片画出
plt.figure(figsize=(30,10), facecolor ='g') #
a = tree.plot_tree(model,
                   feature_names = feature_names,
                   class_names = labels,
                   rounded = True,
                   filled = True,
                   fontsize=14)
plt.show()  

最后可以得到一棵如下的决策树模型:

决策树模型
控制台的如下运行结果,代表着决策树划分时属性的先后顺序,这个顺当然对应了上边的决策树模型的各个结点。下图即表示:先判断第 2 个特征的值,根据其值划分出两个分支;> 2.45 的分支会优先选择第 3 个特征,第 3 个特征判断完之后又选择了第 2 个特征(因为特征值是连续的,所以可能会重复出现在多层!)…

在这里插入图片描述

总结

本篇首先介绍决策树的原理,其关键点就是 最优属性的划分,有三个不同的选择标准供我们选择,详细数学形式和背后的含义可以参考西瓜书或者我的另一篇博文(机器学习第四章—决策树);然后在使用sklearn实现一个决策树模型,加强理解决策树的分类思想。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值