【机器学习基础】鸢尾花的分类 - 机器学习领域的Hello World

1 项目简介

【背景】
假设有一名植物学爱好者对她发现的鸢尾花的品种很感兴趣。她收集了每朵鸢尾花的一些测量数据:花瓣的长度和宽度以及花萼的长度和宽度,所有测量结果的单位都是厘米。她还有一些鸢尾花的测量数据,这些花之前已经被植物学专家鉴定为属于setosa、versicolor或virginica三个品种之一。对于这些测量数据,她可以确定每朵鸢尾花所属的品种。

【目标】构建一个机器学习模型,可以从上述已知品种的鸢尾花测量数据,从而预测新鸢尾花的品种

【分析】监督学习问题;分类问题;

【拓展】

  • 类别:可能输出(鸢尾花的不同品种)
  • 标签:单个数据点的预期输出
  • 样本:机器学习中的个体
  • 特征:样本属性

【补充】from…import…可能造成命名污染,不推荐过多使用

1.1 初识数据

【关键词】Bunch对象;load_iris;

from sklearn.datasets import load_iris
iris_dataset = load_iris()
print('Keys of iris dataset: \n{}'.format(iris_dataset.keys()))
Keys of iris dataset: 
dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 'feature_names', 'filename', 'data_module'])

【DESCR】其对应的值是数据集的简要说明

print(iris_dataset['DESCR']+'\n')
.. _iris_dataset:

Iris plants dataset
--------------------

**Data Set Characteristics:**

    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
                
    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
from Fisher's paper. Note that it's the same as in R, but not as in the UCI
Machine Learning Repository, which has two wrong data points.

This is perhaps the best known database to be found in the
pattern recognition literature.  Fisher's paper is a classic in the field and
is referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant.  One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

.. topic:: References

   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...

【target_names】其对应的值是一个字符串数组,包含我们要预测的话的品种

print('Target names: {}'.format(iris_dataset['target_names']))
Target names: ['setosa' 'versicolor' 'virginica']

【feature_names】其对应的值是一个字符列表,对数据的每个特征进行了说明

print('Feature names: \n{}'.format(iris_dataset['feature_names']))
Feature names: 
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

【data】其里面是花萼长度、花萼宽度、花瓣长度、花瓣宽度,格式为Numpy数组

  • data数组的每一行对应一朵花,列代表每朵花的四个测试数据
  • data数组的形状是样本数与特征数的乘积
print('Type of data: {}'.format(type(iris_dataset['data'])))
Type of data: <class 'numpy.ndarray'>
print('Shape of data: {}'.format(iris_dataset['data'].shape))
Shape of data: (150, 4)
print('First five rows of data:\n{}'.format(iris_dataset['data'][:5]))
First five rows of data:
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]]

【target】一维数组,每朵花对应其中一个数据,品种被转换成0到2的整数

  • 0 setosa
  • 1 versicolor
  • 2 virginica
print('Target:\n{}'.format(iris_dataset['target']))
Target:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]

1.2 训练数据与测试数据

【train_test_split】利用伪随机数生成器将数据集打乱,确保测试集有所有类别的数据

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)
for Xy in list(zip(X_train, y_train))[:10]:
    print(Xy)
(array([5.9, 3. , 4.2, 1.5]), 1)
(array([5.8, 2.6, 4. , 1.2]), 1)
(array([6.8, 3. , 5.5, 2.1]), 2)
(array([4.7, 3.2, 1.3, 0.2]), 0)
(array([6.9, 3.1, 5.1, 2.3]), 2)
(array([5. , 3.5, 1.6, 0.6]), 0)
(array([5.4, 3.7, 1.5, 0.2]), 0)
(array([5. , 2. , 3.5, 1. ]), 1)
(array([6.5, 3. , 5.5, 1.8]), 2)
(array([6.7, 3.3, 5.7, 2.5]), 2)

【shape】查看训练集与测试集的大小

print('X_train shape: {}'.format(X_train.shape))
print('y_train shape: {}'.format(y_train.shape))
print()
print('X_test shape: {}'.format(X_test.shape))
print('y_test shape: {}'.format(y_test.shape))
X_train shape: (112, 4)
y_train shape: (112,)

X_test shape: (38, 4)
y_test shape: (38,)

1.3 观察数据

【目的】找出异常值和特殊值(也许是数据单位不统一)

【方法】可视化(如绘制散点图、散点图矩阵)

import pandas as pd
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
grr = pd.plotting.scatter_matrix(
    iris_dataframe, c=y_train, figsize=(15,15), marker='o',
    hist_kwds={'bins':20}, alpha=0.8)

在这里插入图片描述




2 构建模型:KNN算法

【概述】

  • k-近邻算法采用测量不同特征值之间距离的方法进行分类
  • k的含义:寻找训练集中与新数据最近的k个数据点

【补充】scikit-learn中所有机器学习模型都在各自类中实现

2.1 使用KNeighborsClassfier类的fit方法

import sklearn.neighbors as skln
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)

knn = skln.KNeighborsClassifier(n_neighbors=1)
print(knn.fit(X_train, y_train))
KNeighborsClassifier(n_neighbors=1)

2.2 预测新数据

【事件】我们在野外发现了一朵鸢尾花,花萼长5cm 宽2.9,花瓣长1cm 宽0.2cm。这朵鸢尾花是哪种品种捏?

【警告】这朵花的测试数据转化为二维numpy数组的第一行,请记住scikit-learn的输入数据必须是二维数组

import numpy as np
X_new = np.array([[5,2.9, 1,0.2]]) # 这是个二维numpy数组
print('X_new.shape: {}'.format(X_new.shape))
X_new.shape: (1, 4)

【初试】模型说这朵鸢尾花的标签为0,叫做setosa。它说是就是?验证模型的可信度也是十分重要的

from sklearn.datasets import load_iris

iris_dataset = load_iris()

prediction = knn.predict(X_new)
print('Prediction: {}'.format(prediction))
print('Predicted target_name: {}'.format(iris_dataset['target_names'][prediction]))
Prediction: [0]
Predicted target_name: ['setosa']

2.3 评估模型

【任务】我们可以计算品种预测正确的花所占的比例衡量模型的准确度

【提示】测试集:开始工作

y_pred = knn.predict(X_test)
print('Test_set predictions:\n{}'.format(y_pred))

print('Test_set score: {:.2f}'.format(np.mean(y_pred==y_test)))
Test_set predictions:
[2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
 2]
Test_set score: 0.97

【补充】KNeighborsClassifier类的score方法计算测试集的精度

print('Test_set score: {:.2f}'.format(knn.score(X_test, y_test)))
Test_set score: 0.97

【冷静分析】0.97意味着这个模型中有97%的数据是正确的。也就是说,对于之前输入的新数据,我有97%的把握认为模式猜得对




参考 鸢尾花分类

  • 3
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

维他命C++

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值