@参考Python 机器学习基础教程
鸢尾花分类
一个简单的机器学习应用,构建第一个模型。
对鸢尾花的分类,根据测量数据进行,该测量数据则为特征。测量数据:花瓣的长度和宽度、花萼的长度和宽度,所有测量结果的单位为cm
我们的目标是构建一个机器学习模型
因为有已知品种的鸢尾花的测试数据,所以这是一个监督学习问题。我们要在多个选项中预测其中一个(品种)。这是一个分类(classsification)问题。可能的输出(鸢尾花的不同品种)叫做类别(class)。数据集中共有三个类别(setosa、versicolor、virginica)。对于一个数据点来说,它的品种叫做标签(label)。
1、初识数据
这是一个机器学习和统计学中一个经典的数据集(Iris)。包含在scikit-learn的datasets模块中。可以通过调用load_iris函数来加载数据。
from sklearn.datasets import load_iris
iris_dataset = load_iris()
print("Keys of iris_dataset: \n{}".format(iris_dataset.keys()))
返回值与字典非常相似,里面包含键和值。
DESCR键对应的值时数据集的简要你说明。
print(iris_dataset['DESCR'][:393] + "\n...")
[9] 预测花的种类、[10] 字符串列表,对每个特征进行了说明、[11] 数据的类型、[12] 前五组数据内容
[11]数据的类型中可以看出,包含了150朵不同花的测量数据,机器学习中的个体叫做样本(sample),其属性叫做特征(feature)。data数据的形状(shape)是样本数乘以特征数。
输出品种,品种被转换成0到2的整数。
2、衡量模型是否成功:训练数据与测试数据
不能将构建模型的数据用于评估模型。因为我们的模型会一直记住整个训练集。无法告知模型的泛化(generalize)能力如何(换句话,在新数据上能否正确预测)。
所以,一部分数据用于构建机器学习模型,叫做训练数据(training data)或训练集(training set)。其余的数据用嘞评估模型性能,叫做测试数据(test data)、测试集(test set)或留出集(hold-out set)。
scikit-learn中的train_test_split函数可以打乱数据集并进行拆分。训练集->75%,测试集->25%。
利用andom_state参数指定,使得每次输出的数据是固定不变的。
3、观察数据
检查数据也是发现异常值和特殊值的好方法。
检查数据的最佳方法之一绘制散点图(scatter plot)。
对于多余3个特征的数据集作图,需要绘制散点图矩阵(pair plot)。
import pandas as pd
import mglearn
from pandas.plotting import scatter_matrix
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
grr = scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o',hist_kwds={'bins': 20}, s=60, alpha=.8, cmap = mglearn.cm3)
在绘制过程中,需要将NumPy数组转换成pandas DataFrame的格式
从图中可以看出,利用花瓣和花萼的测量数据基本可以将三个类别区分开。
4、构建第一个模型:K近邻算法(K Nearest neighbor algorithm)
scikit-learn中有许多可用的分类算法。构建此模型只需要保存训练集即可。对一个新的数据点做出预测,算法会在训练集中寻找与这个心数据点距离最近的数据点,然后将找到的数据点的标签赋值给这个新的数据点。
K近邻算法中k的含义是,我们可以考虑训练集中与新数据点最近的任意k个邻居,而不是只考虑最近的拿一个,通常使用奇数。
scikit-learn中所有的机器学习模型都在各自的类中实现,这些类被称为Estimator类。k近邻分类算法是在neighbors模块的KNeighborsClassifier类中实现的。
需要将这个类实例化为一个对象,然后才能使用这个模型。这时我们需要设置模型的参数。KNeighborsClassifier最终要的参数就是邻居的数目,这里我们设为1:
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors = 1)
想要基于训练集来构建模型,需要调用knn对象的fit方法。
knn.fit(X_train, y_train)
5、做出预测
我们发现了一朵鸢尾花,花萼长5cm宽2.9cm,花瓣长1cm宽0.2cm。输入的数组必须是一个二维数组。
调用knn对象的predict方法进行预测。
prediction = knn.predict(X_new)
print("Prediction:{}".format(prediction))
print("Prediction:{}".format(iris_dataset['target_names'][prediction]))
6、评估模型
通常使用测试集进行测试,并将预测结果与标签进行对比,通过计算精度(accuracy)来衡量模型的优劣。
还可以通过使用knn对象的score方法来计算测试集的精度:
说明测试集的精度约为0.97,也就是说97%的概率是预测正确。
小结
fit、predict和score方法是scikit-learn监督学习模型中最常用的接口!!