【机器学习】长文了解IRIS,一个鸢尾花分类项目

简介:本文将带领你从零开始,完成一个简单的机器学习项目。我们将使用Python编程语言和常用的机器学习库,逐步探索数据准备、模型训练和评估的整个过程。通过这个项目,你将掌握机器学习的基本概念和实践经验,为进一步深入学习打下坚实的基础。

在开始这个机器学习练手项目之前,你需要确保已经安装了Python和相关的机器学习库,如NumPy、Pandas和Scikit-learn。如果你还没有安装这些库,可以通过以下命令进行安装: 

# pip install numpy pandas scikit-learn shap

 其中numpy、pandas包含一些计算方法,scikit-learn包含机器学习的方法例如支持向量机,随机森林等等,shap一般用来进行可解释性分析。

1、数据集介绍

iris数据集的中文名是安德森鸢尾花卉数据集,英文全称是Anderson’s Iris data set。iris包含150个样本,对应数据集的每行数据。每行数据包含每个样本的四个特征和样本的类别信息,所以iris数据集是一个150行5列的二维表。

通俗地说,iris数据集是用来给花做分类的数据集,每个样本包含了花萼长度、花萼宽度、花瓣长度、花瓣宽度四个特征(前4列),我们需要建立一个分类器,分类器可以通过样本的四个特征来判断样本属于山鸢尾、变色鸢尾还是维吉尼亚鸢尾(这三个名词都是花的品种)。

鸢尾花图例

2、数据及特征分析

 2.1、导入所需要的模块并加载数据集
import pandas as pd
from sklearn import datasets
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import metrics
iris = datasets.load_iris()
 2.2、查看数据集的描述
print(iris['DESCR'])
print(iris['feature_names'])

output:
.. _iris_dataset:

Iris plants dataset
--------------------

**Data Set Characteristics:**

:Number of Instances: 150 (50 in each of three classes)
:Number of Attributes: 4 numeric, predictive attributes and the class
:Attribute Information:
    - sepal length in cm
    - sepal width in cm
    - petal length in cm
    - petal width in cm
    - class:
            - Iris-Setosa
            - Iris-Versicolour
            - Iris-Virginica

:Summary Statistics:

============== ==== ==== ======= ===== ====================
                Min  Max   Mean    SD   Class Correlation
============== ==== ==== ======= ===== ====================
sepal length:   4.3  7.9   5.84   0.83    0.7826
...
    conceptual clustering system finds 3 classes in the data.
  - Many, many more ...

['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
2.3、将数据转为pandas所支持的DataFrame类型的数据
data=pd.DataFrame(iris.data,columns=iris.feature_names)
label=pd.DataFrame(iris.target,columns=['label'])
data=pd.concat([data,label],axis=1)
data

 2.4、做一下Label到名称的定义
data["class_name"]=data["label"].apply(lambda x: iris.target_names[x])
data

2.5、 数据查看

通过describe方法可以获得特征的数量、平均值、最大值等等信息

data.describe()

2.6、 绘制箱线图

箱线图描述了数据的分布情况,包括:上下界,上下四分位数和中位数,可以简单的查看数据的分布情况。比如:上下四分位数相隔较远的话,一般可以很容易分为2类。

data[iris.feature_names].plot(kind='box', subplots=True, layout=(2, 2), sharex=False, sharey=False)

对data的特征列进行绘制箱线图的操作

2.7、绘制直方图
data[iris.feature_names].hist()

 2.8、使用Plot直接展示数据的分布情况
data[iris.feature_names].plot()

 2.9、KDE核密度估计
data[iris.feature_names].plot(kind='kde')

 2.10、径向特征可视化

径向可视化是多维数据降维的可视化方法,不管是数据分析还是机器学习,降维是最基础的方法之一,通过降维,可以有效的减少复杂度。径向坐标可视化是基于弹簧张力最小化算法。它把数据集的特征映射成二维目标空间单位圆中的一个点,点的位置由系在点上的特征决定。把实例投入圆的中心,特征会朝圆中此实例位置(实例对应的归一化数值)“拉”实例。

cir_name=[*iris.feature_names,'class_name']
ax = pd.plotting.radviz(data[cir_name], 'class_name', colormap='brg')
ax.add_artist(plt.Circle((0,0), 1, color='r', fill = False))

2.11、绘制小提琴图

 

# 设置颜色主题
antV = ['#1890FF', '#2FC25B', '#FACC14', '#223273', '#8543E0', '#13C2C2', '#3436c7', '#F04864'] 

# 绘制  Violinplot
f, axes = plt.subplots(2, 2, figsize=(8, 8), sharex=True)
sns.despine(left=True)

sns.violinplot(x='class_name', y='sepal length (cm)', data=data[cir_name], palette=antV, ax=axes[(0, 0)])
sns.violinplot(x='class_name', y='sepal width (cm)', data=data[cir_name], palette=antV, ax=axes[(0, 1)])
sns.violinplot(x='class_name', y='petal length (cm)', data=data[cir_name], palette=antV, ax=axes[(1, 0)])
sns.violinplot(x='class_name', y='petal width (cm)', data=data[cir_name], palette=antV, ax=axes[(1, 1)])

plt.show()

2.12、绘制点图
# 设置颜色主题
antV = ['#1890FF', '#2FC25B', '#FACC14', '#223273', '#8543E0', '#13C2C2', '#3436c7', '#F04864'] 

# 绘制  Violinplot
f, axes = plt.subplots(2, 2, figsize=(8, 8), sharex=True)
sns.despine(left=True)

sns.pointplot(x='class_name', y='sepal length (cm)', data=data[cir_name], palette=antV, ax=axes[(0, 0)])
sns.pointplot(x='class_name', y='sepal width (cm)', data=data[cir_name], palette=antV, ax=axes[(0, 1)])
sns.pointplot(x='class_name', y='petal length (cm)', data=data[cir_name], palette=antV, ax=axes[(1, 0)])
sns.pointplot(x='class_name', y='petal width (cm)', data=data[cir_name], palette=antV, ax=axes[(1, 1)])

plt.show()

 

2.13、绘制多变量联合分布图
g = sns.pairplot(data=data[cir_name], palette=antV, hue= 'class_name')

 

 从图中可以看出,利用花瓣和花萼的测量数据基本可以将三个类别区分开。这说明机器学习模型很可能可以学会区分它们。

2.14、Andrews曲线

Andrews曲线将每个样本的属性值转化为傅里叶序列的系数来创建曲线,这对于检测时间序列数据中的异常值很有用。通过将每一类曲线标成不同颜色可以可视化聚类数据,属于相同类别的样本的曲线通常更加接近并构成了更大的结构。

plt.subplots(figsize = (10,8))
pd.plotting.andrews_curves(data[cir_name], 'class_name', colormap='cool')
plt.show()

平行坐标可以看到数据中的类别以及从视觉上估计其他的统计量。使用平行坐标时,每个点用线段联接,每个垂直的线代表一个属性, 一组联接的线段表示一个数据点。可能是一类的数据点会更加接近。

pd.plotting.parallel_coordinates(data[cir_name], 'class_name', colormap='cool')

 通过上图的花瓣长度特征大致可以看到,其中setosa的花瓣长度在区间1-2cm之间,具有很好的分类性,可以通过分类模型进行分类。

2.15、长度和宽度的线性回归分析
g = sns.lmplot(data=data[cir_name], x='sepal width (cm)', y='sepal length (cm)', palette=antV, hue='class_name')
g = sns.lmplot(data=data[cir_name], x='petal width (cm)', y='petal length (cm)', palette=antV, hue='class_name')

 

上述三种类别大致都呈现一个正相关的线性关系,也就是说花的宽度越宽,它的长度也就越长。

 2.16、热图绘制相关性
fig=plt.gcf()
fig.set_size_inches(12, 8)
fig=sns.heatmap(data[iris.feature_names].corr(), annot=True, linewidths=1, linecolor='k', \
                square=True, mask=False, vmin=-1, vmax=1, cbar_kws={"orientation": "vertical"}, cbar=True)

验证了上述所作的线性相关分析,其中长度和宽度具有很强的相关性,花瓣的长和宽相关性可以达到0.96.

 3、搭建模型进行分类

3.1、使用KNN模型

在 scikit-learn 库中,提供了多种分类方法供我们选择。本文将重点介绍 k-最近邻(k-NN)分类器,这是一种直观且易于理解的算法。

当需要对一个新数据点进行分类预测时,k-NN 算法会在训练数据集中寻找与该新数据点距离最近的点。一旦找到,新数据点的类别将被赋予与这些最近点相同的标签。

在 k-NN 算法中,参数 k 表示我们考虑的最近邻居的数量。例如,我们可以选取距离新数据点最近的 3 个或 5 个邻居,而不是单一的最近邻居。之后,算法将统计这些邻居中出现次数最多的类别,并以此作为新数据点的预测类别。
 

 对数据集进行划分,我们划分为训练集和测试集,其中用测试集来验证模型的效果。

from sklearn.model_selection import train_test_split
X = data[iris.feature_names]        # 等价于iris_dataset.data
y = data['label']     # 等价于iris_dataset.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

搭建模型并使用训练集训练模型

from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=1)
knn.fit(X_train, y_train)

使用测试集验证模型。并计算出准确率

y_pred = knn.predict(X_test)
print("Test set predictions: \n {}".format(y_pred))
# 三种计算准确率的方法
print("Test set score: {:.2f}".format(np.mean(y_pred == y_test)))
print('Test set score: {:.2f}'.format(metrics.accuracy_score(y_pred, y_test)))
print("Test set score: {:.2f}".format(knn.score(X_test, y_test)))
3.2、使用不同的模型进行对比
from sklearn.linear_model import LogisticRegression 
from sklearn.neighbors import KNeighborsClassifier
from sklearn import svm
from sklearn.tree import DecisionTreeClassifier
from sklearn import metrics 

model_names=["SVM_linear","SVM_rbf","LR","DT","KNN"]
models=[svm.SVC(kernel='linear'),svm.SVC(kernel='rbf'),LogisticRegression(),DecisionTreeClassifier(),KNeighborsClassifier(n_neighbors=3)]

for i in range(len(model_names)):
    models[i].fit(X_train,y_train)
    prediction = models[i].predict(X_test)
    print("模型:"+model_names[i],"准确率:",metrics.accuracy_score(y_test,prediction))

模型:SVM_linear 准确率: 0.9736842105263158

模型:SVM_rbf 准确率: 0.9736842105263158

模型:LR 准确率: 0.9736842105263158

模型:DT 准确率: 0.9736842105263158

模型:KNN 准确率: 0.9736842105263158

 4、使用SHAP进行可解释性分析

SHAP 属于模型事后解释的方法,它的核心思想是计算特征对模型输出的边际贡献,再从全局和局部两个层面对“黑盒模型”进行解释。SHAP构建一个加性的解释模型,所有的特征都视为“贡献者”。对于每个预测样本,模型都产生一个预测值,SHAP value就是该样本中每个特征所分配到的数值。

4.1、对单个样本进行分析
import xgboost
import shap
shap.initjs()

# train an XGBoost model
X, y = shap.datasets.iris()
model = xgboost.XGBRegressor().fit(X, y)

# explain the model's predictions using SHAP
# (same syntax works for LightGBM, CatBoost, scikit-learn, transformers, Spark, etc.)
explainer = shap.Explainer(model)
shap_values = explainer(X)
shap_exp = shap.Explanation(shap_values)

# visualize the first prediction's explanation
shap.plots.waterfall(shap_values[2])

4.2、 绘制beeswarm蜂群图
shap.plots.beeswarm(shap_values)

花瓣的长度,特征区分比较明显,说明区分性能较好

 4.3、绘制heatmap热图
shap.plots.heatmap(shap_values)

其中f(x)输出为0,1,2时花瓣的长度由蓝色到浅红到红色进行变化,说明该特征确实是对分类起到了很好的效果。

 欢迎关注下方【GZH】,一起交流讨论!

  • 20
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值