莺尾花分类
一、实验简介
1、目标:
构建一个机器学习模型,从已知品种的莺尾花测量数据中进行学习,从而能够预测新莺尾花的品种。
2、莺尾花测量数据:
花瓣长度、花瓣宽度、花萼长度、花萼宽度。单位:厘米(cm)。
莺尾花品种:setosa、versicolor、virginica 三个品种之一。
莺尾花图片:
3、问题分类:
因为我们有已知品种的莺尾花的测量数据,所以这是一个监督学习问题。在这个问题中,我们要在多个选项中预测其中一个(莺尾花的品种)。这是一个分类(classification)问题的示例。可能的输出(莺尾花的不同品种)叫做类别(class)。数据集中的每朵莺尾花都属于三个类别之一,所以这是一个三分类问题。
二、实验
1、初识数据
在本例中我们用到了莺尾花(Iris)数据集,这是机器学习和统计学中一个经典的数据集。它包含在 scikit-learn 的 datasets 模块中。我们可以调用 load_iris 函数来加载数据:
load_iris 返回的 iris 对象是一个 Banch 对象,与字典非常相似,里面包含键和值:
(1) DESRC 键对应的值是数据集的简要说明:
(2) target_names 键对应的值是一个字符串数组,里面包含我们要预测的花的品种:
(3) feature_names 键对应的值是一个字符串列表,对每一个特征进行了说明:
(4) 数据包含在 target 和 data 字段中。data 里面是 花萼长度、花萼宽度、花瓣长度、花瓣宽度 的测量数据,格式为 NumPy 数组:
data 数组的每一行对应一朵花,列代表每朵花的四个测量数据:
可以看出,数组中包含 150 朵不同花的测量数据。
前面说过,机器学习中的个体叫做样本(sample),其属性叫做特征(feature)。
data 数组的形状(shape)是样本数乘以特征数。这是 scikit-learn 中的约定,我们的数据形状应始终遵循这个约定。
下面给出了前 5 个样本的特征值:
从数据中可以看出,前 5 朵花的花瓣宽度都是 0.2cm,第一朵花的花萼最长,是 5.1cm。
(5) target 数组包含的是测量过的每朵花的品种,也是一个 NumPy 数组:
target 是一维数组,每朵花对应其中一个数据;品种被转换成从 0 到 2 的整数:
上述数字的代表含义由 iris.get("target_names") 数组给出:0 代表 setosa,1 代表 versicolor,2 代表 virginica。
2、衡量模型是否成功:训练数据与测试数据
我们需要知道模型的 模型的 泛化(generalize)能力如何(换句话说,在新数据上能否正确预测)。
我们要用新数据来评估模型的性能。新数据是指模型之前没有见过的数据,而我们有这些新数据的标签。通常的做法是将收集好的带标签的数据(此例中是 150 朵花的测量数据)分成两部分:
a、训练数据(training data)、训练集(training set):
用于构建机器学习模型的数据,叫作 训练数据(training data)、训练集(training set)。
b、测试数据(test data)、训练集(test set)、留出集(hold-out set):
用来评估模型性能的数据,叫作 测试数据(test data)、训练集(test set)、留出集(hold-out set)。
本实验中,我们通过使用 scikit-learn 中的 train_test_split 函数将数据集打算并进行拆分。将 75% 的行数据及对应标签作为训练集,剩下 25% 的数据及标签作为测试集。训练集与测试集的分配比例可以是随意的,但使用 25% 的数据作为测试集是很好的经验法则。
scikit-learn 中的数据通常用大写的 X 表示,而标签用小写的 y 表示。这是受到了数学标准公式 的启发,其中 是函数的输入, 是输出。我们用大写的 X 是因为数据是一个二维数组(矩阵),用小写的 是因为目标是一个一维数组(向量),这也是数学种的约定。
对数据调用 train_test_split,并对输出结果采用下面这种命名方法:
为了确保多次运行同一函数能够得到相同的输出,我们利用 random_state 参数指定了随机数生成器的种子。这样函数输出就是固定不变的,所以这行代码的输出始终相同。
X_train 包含 75% 的行数据,X_test 包含剩下的25%:
3、要事第一:观察数据
在构建机器学习模型之前,通常最好检查一下数据,看看如果不用机器学习能不能轻松完成任务,或者需要的信息有没有包含在数据中。
此外,检查数据也是发现异常值和特殊值的好办法。举个例子,可能有些莺尾花的测量单位是英寸而不是厘米。在现实世界中,经常会遇到不一致的数据和意料之外的测量数据。
检查数据的最佳方法之一就是将其可视化。一种可视化方法是绘制散点图(scatter plot)。数据散点图将一个特征作为 轴,另一个特征作为 轴,将每一个数据点绘制为图上的一个点。不幸的是,计算机屏幕只有两个维度,所以我们一次只能绘制两个特征(也可能是 3 个)。用这种方法难以对多于 3 个特征的的数据集作图。解决这个问题的一种方法是绘制散点图矩阵(pair plot),从而可以两两查看所有的特征。如果特征数不多的话,比如我们这里有 4 个,这种方法是很合理的。但是我们应该记住,散点图矩阵无法同时显示所有特征之间的关系,所以这种可视化方法可能无法展示数据的某些有趣内容。
Iris 数据集的散点图矩阵,按类别标签着色(数据点的颜色与莺尾花的品种相对应):
从图中可以看出,利用花瓣和花萼的测量数据基本可以将三个类别区分开。这说明机器学习模型很可能可以学会区分它们。
4、构建第一个模型:k 近邻算法
k 近邻算法:我们可以考虑训练集中与新数据点最近的任意 k 个邻居(比如说,距离最近的 3 个或 5 个邻居),而不是只考虑最近的那一个。然后,我们可以用这些邻居中数量最多的类别做出预测。
scikit-learn 中所有机器学习模型都在各自的类中实现,这些类被称为 Estimator 类。k 近邻分类算法是在 neighbors 模块的 KNeighborsClassifier 类中实现的。我们需要将这个类实例化为一个对象,然后才能使用这个模型。这时我们需要设置模型的参数。KNeighborsClassifier 最重要的参数就是邻居的数目,这里我们设为1:
kun 对象对算法进行了封装,既包括用训练数据构建模型的算法,也包括对新数据点进行预测的算法。它还包括算法从训练数据中提取的信息。对于 KNeighborsClassifier 来说,里面只保存了训练集。
想要基于训练集来构建模型,需要调用 knn 对象的 fit 方法,输入参数为 X_train 和 y_train,二者都是 NumPy 数组,前者包含训练数据,后者包含相应的训练标签:
fit 方法返回的是 knn 对象本身并做原处修改,因此我们得到了分类器的字符串表示。从中可以看出构建模型时用到的参数。几乎所有参数都是默认值,但我们也会注意到 n_neighbors=1,这是我们传入的参数。scikit-learn 中的大多数模型都有很多参数,但多用于速度优化或非常特殊的用途。
5、做出预测
预测情景:我们在野外发现了一朵莺尾花,花萼长 5cm 宽 2.9cm,花瓣长 1cm 宽 0.2cm。
预测问题:这朵莺尾花属于哪个品种?
我们可以将这些数据放在一个 NumPy 数组中,再次计算形状,数组形状为样本数(1)乘以特征数(4):
注意:这里将这朵花的测量数据转换为二维 NumPy 数组的一行,因为 scikit-learn 的输入数据必须是二维数组。
调用 knn 对象的 predict 方法进行预测:
6、评估模型
对测试数据中的每朵莺尾花进行预测,并将预测结果与标签(已知的品种)进行对比。
通过计算精度(accuracy)来衡量模型的优劣,精度就是品种预测正确的花所占的比例:
还可以通过 knn 对象的 score 方法来计算测试集的精度:
对于这么模型来说,测试集的精度约为 0.97。
7、小结与展望
(1) 实验目的:利用莺尾花的物理测量数据来预测其品种。
(2) 实验过程分析:
a、我们在构建模型时用到了由专家标注过的测量数据集,专家已经给出了花的正确品种,因此这是一个监督学习问题。
b、一共有三个品种:setosa、versicolor、virginica,因此这是一个三分类问题。
c、在分类问题中,可能的品种被称为类别(class),每朵花的品种被称为标签(label)。
d、莺尾花(Iris) 数据集包含两个 NumPy 数组:
一个包含数据,在 scikit-learn 中被称为 X;一个包含正确的输出或预期输出,称为 y。
数组 X 是特征的二维数组,每个数据点对应一行,每个特征对应一列。
数组 y 是一维数组,里面包含一个类别标签,对每个样本都是 0 到 2 之间的整数。
e、我们将数据集分成训练集(training set)和测试集(test set),前者用于构建模型,后者用于评估模型对前所未见的新数据的泛化能力。
f、我们选择了 k 近邻分类算法,根据新数据点在训练集中距离最近的邻居来进行预测。
该算法在 KNeighborsClassifier 类中实现,里面既包含构建模型的算法,也包含利用模型进行预测的算法。
我们将类实例化,并设定参数。
然后调用 fit 方法来构建模型,传入训练数据(X_train)和训练输出(y_train)作为参数。
我们利用 score 方法来评估模型,该方法计算的是模型精度。我们将 score 方法用于测试数据集和测试集标签,得出模型的精度约为 97%,也就是说,该模型在测试集上 97% 的预测都是正确的。
这让我们有信心将模型应用于新数据(在我们的例子中是新花的测量数据),并相信模型在约 97% 的情况下都是正确的。
(3) 实验代码: