本文使用KNN算法对sklearn中自带的鸢尾花数据集进行了分类。
鸢尾花数据集
安德森鸢尾花卉数据集(英文:Anderson’s Iris data set),也称鸢尾花卉数据集(英文:Iris flower data set)或Fisher鸢尾花卉数据集(英文:Fisher’s Iris data set),是一种多重变量分析的数据集。它最初是埃德加·安德森从加拿大加斯帕半岛的鸢尾属花朵中提取的形态学变异数据[1]。
其数据集包含了150个样本,都属于鸢尾属下的3个亚属,分别是山鸢尾、变色鸢尾和维吉尼亚鸢尾。每个样本都包含4项特征,即花萼和花瓣的长度和宽度,它们可用于样本的定量分析。基于这些特征,费雪发展了能够确定其属种的线性判别分析。
python代码实现
import numpy as np
import matplotlib as plt
import pandas as pd
from sklearn.datasets import load_iris
%matplotlib inline
iris_dataset = load_iris()
print("Keys of iris_dataset: \n{}".format(iris_dataset.keys()))
Keys of iris_dataset:
dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename', 'data_module'])
# DESCR键对应的值对数据集进行了简要的说明
print(iris_dataset['DESCR'][:193] + "\n ...")
.. _iris_dataset:
Iris plants dataset
--------------------
**Data Set Characteristics:**
:Number of Instances: 150 (50 in each of three classes)
:Number of Attributes: 4 numeric, pre
...
# target_names键对应的值包括了花的品种
print("Target names: {}".format(iris_dataset['target_names']))
Target names: ['setosa' 'versicolor' 'virginica']
# feature_names键对应的值包括了花的特征
print("Feature names :{}".format(iris_dataset['feature_names']))
Feature names :['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
# 输出数据的格式和形状
print("Type of data: {}".format(type(iris_dataset['data'])))
print("Shape of data: {}".format(iris_dataset['data'].shape))
Type of data: <class 'numpy.ndarray'>
Shape of data: (150, 4)
# 输出target的数据形式和大小
print("Type of target: {}".format(type(iris_dataset['target'])))
print("Shape of target: {}".format(iris_dataset['target'].shape))
Type of target: <class 'numpy.ndarray'>
Shape of target: (150,)
# 输出target
print("Target: \n{}".format(iris_dataset['target']))
Target:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
2 2]
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(iris_dataset['data'],iris_dataset['target'],random_state=0)
# 输出X_train,y_train的形状
print("X_train shape:{}".format(X_train.shape))
print("y_train shape:{}".format(y_train.shape))
X_train shape:(112, 4)
y_train shape:(112,)
# 输出X_test,y_test的形状
print("X_test shape:{}".format(X_test.shape))
print("y_test shape:{}".format(y_test.shape))
X_test shape:(38, 4)
y_test shape:(38,)
# 绘制散点图矩阵
iris_dataframe = pd.DataFrame(X_train,columns=iris_dataset.feature_names)
grr = pd.plotting.scatter_matrix(iris_dataframe,c=y_train,figsize=(15,15),marker='o',
hist_kwds={'bins':20},s=60, alpha=0.8)
下图为散点图矩阵,相较于散点图,散点图矩阵能够呈现出多个维度变量之间的关系,方便对于整个数据的结构特征有一个直观的把控。
可以看出,利用花瓣和花萼的数据基本可以将数据进行区分,说明机器学习的模型可能能学会区分它们。
# 使用KNN对鸢尾花进行分类
from sklearn.neighbors import KNeighborsClassifier
# knn对象中包括了训练数据构建模型的算法和预测的算法
knn = KNeighborsClassifier(n_neighbors=1)
knn.fit(X_train,y_train)
# 使用建立好的KNN模型对新的样本点进行预测
X_new = np.array([[5, 2.9, 1, 0.2]])
prediction = knn.predict(X_new)
print("Predicted target name:{}".format(iris_dataset['target_names'][prediction]))
运行结果为:
Predicted target name:[‘setosa’]
# 对模型的精度进行计算
print("Test set score :{:.2f}".format(knn.score(X_test,y_test)))
计算精度为:
Test set score :0.97
# 检测模型的其他指标
from sklearn.metrics import classification_report
y_pre = knn.predict(X_test)
print(classification_report(y_test,y_pre))
计算其他性能指标:
precision recall f1-score support
0 1.00 1.00 1.00 13
1 1.00 0.94 0.97 16
2 0.90 1.00 0.95 9
accuracy 0.97 38
macro avg 0.97 0.98 0.97 38
weighted avg 0.98 0.97 0.97 38
参考文献:
[1] :Edgar Anderson. The irises of the Gaspé Peninsula. Bulletin of the American Iris Society. 1935, 59: 2–5.