机器学习 —— PCA降维和交叉验证

一、PCA降维

为什么要降维:

  1. 数据在低维下更容易处理、更容易使用;
  2. 重要特征更能在数据中明确地显示出来;比如:只有两维或者三维的话,更便于可视化展示;
  3. 去除数据噪声,把存在着错误或异常(偏离期望值)的数据去除, 降低干扰
  4. 降低算法开销, 特征过多训练速度会比较慢.

导包

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

手写数字识别

获取数据digits.csv

digits = pd.read_csv('../data/digits.csv')
digits

digits.shape
# (42000, 785)

data = digits.iloc[:, 1:].copy()
target = digits['label']

target.unique()
# array([1, 0, 4, 7, 3, 5, 8, 9, 2, 6], dtype=int64)

data.shape
# (42000, 784)

data.iloc[0].values.reshape(28, 28)
'''
array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0, 188, 255,  94,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0, 191, 250, 253,  93,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0, 123, 248, 253, 167,  10,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,  80, 247, 253, 208,  13,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,  29, 207, 253, 235,  77,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,  54, 209, 253, 253,  88,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,  93, 254, 253, 238, 170,  17,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         23, 210, 254, 253, 159,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  16,
        209, 253, 254, 240,  81,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  27,
        253, 253, 254,  13,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  20, 206,
        254, 254, 198,   7,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 168, 253,
        253, 196,   7,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  20, 203, 253,
        248,  76,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,  22, 188, 253, 245,
         93,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0, 103, 253, 253, 191,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  89, 240, 253, 195,  25,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  15, 220, 253, 253,  80,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  94, 253, 253, 253,  94,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  89, 251, 253, 250, 131,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0, 214, 218,  95,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0]], dtype=int64)
'''

plt.imshow(data.iloc[10].values.reshape(28, 28), cmap='gray')

 

特征太多, 训练速度会比较慢, 这个时候我们需要进行降维

  • from sklearn.svm import SVC
from sklearn.svm import SVC

划分数据集

  • from sklearn.model_selection import train_test_split
from sklearn.model_selection import train_test_split



x_train, x_test, y_train, y_test = train_test_split(data, target, test_size=0.2)

x_train.shape, x_test.shape
# ((33600, 784), (8400, 784))

不进行降维,会执行很长时间

svc = SVC()

# %timeit   计算耗时较小的代码,精确计算
# %time  计算耗时较大的代码

%time svc.fit(x_train, y_train)
'''
Wall time: 1min 47s
SVC()
'''

PCA降维

PCA : 主成分分析

from sklearn.decomposition import PCA

from sklearn.decomposition import PCA

# n_components: how many components (features) to keep after the projection;
# here the 784 pixel features are reduced to 50.
pca = PCA(n_components=50, whiten=True)

# Fit the PCA on the training set only and project it in the same step.
# The fitted `pca` object is reused later to transform the test set.
x_train_pca = pca.fit_transform(x_train)
x_train_pca.shape
# (33600, 50)

svc = SVC()

%time svc.fit(x_train_pca, y_train)
'''
Wall time: 34.8 s
SVC()
'''

x_test.shape,  x_train_pca.shape
# ((8400, 784), (33600, 50))

# 预测
x_test_pca = pca.transform(x_test)

x_test_pca.shape
# (8400, 50)

svc.predict(x_test_pca)
# array([7, 7, 5, ..., 7, 1, 1], dtype=int64)

svc.score(x_test_pca, y_test)
# 0.9783333333333334

画图

  • 绘制前60个测试数据(10行 × 6列)

x_test_pca.shape
# (8400, 50)

x_test.shape
# (8400, 784)

y_pred = svc.predict(x_test_pca)

y_test.values[0]
# 7

y_pred[0]
# 7

# Draw the first 60 test digits on a 10 x 6 grid; every panel is titled with
# the true label and the label predicted by the SVC.
plt.figure(figsize=(30, 50))

n_rows, n_cols = 10, 6
for i in range(n_rows * n_cols):
    axes = plt.subplot(n_rows, n_cols, i + 1)

    # Each row of x_test holds one flattened 28x28 image.
    image = x_test.iloc[i].values.reshape(28, 28)
    axes.imshow(image)

    real, pred = y_test.values[i], y_pred[i]
    axes.set_title(f'Real: {real}\nPred: {pred}', fontsize=16)
    

二、交叉验证

  • 数据需要分为:
    • 训练数据:用于模型开发
    • 验证数据:用于验证相同模型的性能
  • 我们经常将数据集随机分为训练数据和测试数据,以开发机器学习模型。 训练数据用于训练ML模型,同一模型在独立的测试数据上进行测试以评估模型的性能。

  • 随着划分数据集时随机状态(random_state)的变化,模型的准确率也会发生变化,因此我们无法得到一个固定的准确率。 测试数据应与训练数据相互独立,以免发生数据泄漏。 在使用训练数据开发ML模型的过程中,需要评估模型的性能。 这就是交叉验证数据的重要性。

  • 交叉验证: 减轻数据取样不平衡造成的误差, 评估不同模型的性能.

导包

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

加载数据:

  • from sklearn.datasets import load_iris
from sklearn.datasets import load_iris

data, target = load_iris(return_X_y=True)

SVM模型有两个非常重要的参数C与gamma

  • C参数: 惩罚系数,即对误差的宽容度。

    • C越高,说明越不能容忍出现误差,容易过拟合。
    • C越小,容易欠拟合。
    • C过大或过小,泛化能力变差
  • gamma参数

    • gamma 是选择RBF函数作为kernel后,该函数自带的一个参数。
    • gamma 隐含地决定了数据映射到新的特征空间后的分布,
    • gamma越大,支持向量越少,
    • gamma值越小,支持向量越多。支持向量的个数影响训练与预测的速度
    • gamma设的太大(对应高斯核的σ很小),核函数只会作用于支持向量样本附近,对未知样本的分类效果很差,可能出现训练准确率很高而测试准确率不高的情况,也就是通常说的过拟合(如果让σ无穷小,即gamma无穷大,则理论上高斯核的SVM可以拟合任何非线性数据,但容易过拟合);而如果gamma设的过小,则会造成平滑效应太大,无法在训练集上得到特别高的准确率,也会影响测试集的准确率。
  • C,gamma相互独立,便于并行化进行

所以找到一对合适的C和gamma参数非常重要

找到合适的C和gamma

划分数据集

  • from sklearn.model_selection import train_test_split
from sklearn.model_selection import train_test_split

from sklearn.svm import SVC

x_train, x_test, y_train, y_test = train_test_split(data, target, test_size=0.2)

svc = SVC()
svc.fit(x_train, y_train)
display(svc.score(x_train, y_train),  svc.score(x_test, y_test))
'''
0.975
0.9333333333333333
'''

KFold : k-fold cross-validation

  • K-交叉验证, K折交叉验证
  • K表示等分成几份
  • 在k折交叉验证中,原始数据集被平均分为k个子部分或折叠。 从k折或组中,对于每次迭代,选择一组作为验证数据,其余(k-1)个组选择为训练数据。
  • 优点:
    • 该模型偏差低
    • 时间复杂度低
    • 整个数据集可用于训练和验证
  • 缺点:
    • 不适合不平衡数据集
  • from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.model_selection import KFold, StratifiedKFold

# n_splits=5 : 拆分成几份,默认5份
kf = KFold(n_splits=5)
kf
# KFold(n_splits=5, random_state=None, shuffle=False)

data.shape
# (150, 4)

kf.split(data)
# <generator object> 生成器对象

list( kf.split(data) )
'''
[(array([ 30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,
          43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,
          69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,
          82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,
          95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107,
         108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
         121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133,
         134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,
         147, 148, 149]),
  array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
         17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])),
 (array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
          13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
          26,  27,  28,  29,  60,  61,  62,  63,  64,  65,  66,  67,  68,
          69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,
          82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,
          95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107,
         108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
         121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133,
         134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,
         147, 148, 149]),
  array([30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46,
         47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59])),
 (array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
          13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
          26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
          39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
          52,  53,  54,  55,  56,  57,  58,  59,  90,  91,  92,  93,  94,
          95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107,
         108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
         121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133,
         134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,
         147, 148, 149]),
  array([60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
         77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89])),
 (array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
          13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
          26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
          39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
          52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
          65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
          78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89, 120,
         121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133,
         134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,
         147, 148, 149]),
  array([ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102,
         103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115,
         116, 117, 118, 119])),
 (array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
          13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
          26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
          39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
          52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
          65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
          78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
          91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
         104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
         117, 118, 119]),
  array([120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132,
         133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145,
         146, 147, 148, 149]))]
'''

data
'''
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.6, 1.4, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [4.5, 2.3, 1.3, 0.3],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.1, 3.8, 1.9, 0.4],
       [4.8, 3. , 1.4, 0.3],
       [5.1, 3.8, 1.6, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.9, 3.1, 4.9, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [4.9, 2.4, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [5. , 2. , 3.5, 1. ],
       [5.9, 3. , 4.2, 1.5],
       [6. , 2.2, 4. , 1. ],
       [6.1, 2.9, 4.7, 1.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.7, 3.1, 4.4, 1.4],
       [5.6, 3. , 4.5, 1.5],
       [5.8, 2.7, 4.1, 1. ],
       [6.2, 2.2, 4.5, 1.5],
       [5.6, 2.5, 3.9, 1.1],
       [5.9, 3.2, 4.8, 1.8],
       [6.1, 2.8, 4. , 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.4, 2.9, 4.3, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [6.7, 3. , 5. , 1.7],
       [6. , 2.9, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [5.5, 2.4, 3.7, 1. ],
       [5.8, 2.7, 3.9, 1.2],
       [6. , 2.7, 5.1, 1.6],
       [5.4, 3. , 4.5, 1.5],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [6.3, 2.3, 4.4, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.6, 4. , 1.2],
       [5. , 2.3, 3.3, 1. ],
       [5.6, 2.7, 4.2, 1.3],
       [5.7, 3. , 4.2, 1.2],
       [5.7, 2.9, 4.2, 1.3],
       [6.2, 2.9, 4.3, 1.3],
       [5.1, 2.5, 3. , 1.1],
       [5.7, 2.8, 4.1, 1.3],
       [6.3, 3.3, 6. , 2.5],
       [5.8, 2.7, 5.1, 1.9],
       [7.1, 3. , 5.9, 2.1],
       [6.3, 2.9, 5.6, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [4.9, 2.5, 4.5, 1.7],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 2.5, 5.8, 1.8],
       [7.2, 3.6, 6.1, 2.5],
       [6.5, 3.2, 5.1, 2. ],
       [6.4, 2.7, 5.3, 1.9],
       [6.8, 3. , 5.5, 2.1],
       [5.7, 2.5, 5. , 2. ],
       [5.8, 2.8, 5.1, 2.4],
       [6.4, 3.2, 5.3, 2.3],
       [6.5, 3. , 5.5, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.2, 5. , 1.5],
       [6.9, 3.2, 5.7, 2.3],
       [5.6, 2.8, 4.9, 2. ],
       [7.7, 2.8, 6.7, 2. ],
       [6.3, 2.7, 4.9, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3.2, 6. , 1.8],
       [6.2, 2.8, 4.8, 1.8],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [7.4, 2.8, 6.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [6.3, 2.8, 5.1, 1.5],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.4, 5.6, 2.4],
       [6.4, 3.1, 5.5, 1.8],
       [6. , 3. , 4.8, 1.8],
       [6.9, 3.1, 5.4, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [5.8, 2.7, 5.1, 1.9],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]])
'''

data[[0,2,1]]
'''
array([[5.1, 3.5, 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.9, 3. , 1.4, 0.2]])
'''

target[[0, 2, 1]]
# array([0, 0, 0])

# data : x_train  x_test
# target: y_train  y_test


# Manual K-fold cross-validation: one fresh SVC per split, scored on the
# fold that was held out of training.
score_list = []

# data and target are ndarrays, so the index arrays yielded by split()
# can be used to select rows directly.
for train, test in kf.split(data):
    x_train, y_train = data[train], target[train]
    x_test, y_test = data[test], target[test]

    # fit() returns the estimator itself, so creation and training chain.
    svc = SVC().fit(x_train, y_train)

    score = svc.score(x_test, y_test)
    print(score)
    score_list.append(score)

# Mean accuracy over all folds
np.mean(score_list)
'''
1.0
1.0
0.8333333333333334
0.9333333333333333
0.7
0.8933333333333333
'''

StratifiedKFold

  • 分层的KFold
skf = StratifiedKFold(n_splits=5)

list( skf.split(data, target) )

# 分的更细
'''
[(array([ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,
          23,  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
          36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,
          49,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
          72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
          85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
         121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133,
         134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,
         147, 148, 149]),
  array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  50,  51,  52,
          53,  54,  55,  56,  57,  58,  59, 100, 101, 102, 103, 104, 105,
         106, 107, 108, 109])),
 (array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  20,  21,  22,
          23,  24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
          36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,
          49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  70,  71,
          72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
          85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 120,
         121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133,
         134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,
         147, 148, 149]),
  array([ 10,  11,  12,  13,  14,  15,  16,  17,  18,  19,  60,  61,  62,
          63,  64,  65,  66,  67,  68,  69, 110, 111, 112, 113, 114, 115,
         116, 117, 118, 119])),
 (array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
          13,  14,  15,  16,  17,  18,  19,  30,  31,  32,  33,  34,  35,
          36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,
          49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,
          62,  63,  64,  65,  66,  67,  68,  69,  80,  81,  82,  83,  84,
          85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
         111, 112, 113, 114, 115, 116, 117, 118, 119, 130, 131, 132, 133,
         134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,
         147, 148, 149]),
  array([ 20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  70,  71,  72,
          73,  74,  75,  76,  77,  78,  79, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129])),
 (array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
          13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
          26,  27,  28,  29,  40,  41,  42,  43,  44,  45,  46,  47,  48,
          49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,
          62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,
          75,  76,  77,  78,  79,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
         111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
         124, 125, 126, 127, 128, 129, 140, 141, 142, 143, 144, 145, 146,
         147, 148, 149]),
  array([ 30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  80,  81,  82,
          83,  84,  85,  86,  87,  88,  89, 130, 131, 132, 133, 134, 135,
         136, 137, 138, 139])),
 (array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
          13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
          26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
          39,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,
          62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,
          75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,
          88,  89, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110,
         111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
         124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136,
         137, 138, 139]),
  array([ 40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  90,  91,  92,
          93,  94,  95,  96,  97,  98,  99, 140, 141, 142, 143, 144, 145,
         146, 147, 148, 149]))]
'''

# Manual stratified K-fold cross-validation: one fresh SVC per split, scored
# on the fold that was held out of training.
score_list = []

# data and target are ndarrays, so the index arrays yielded by split()
# can be used to select rows directly.
for train, test in skf.split(data, target):
    x_train = data[train]
    y_train = target[train]
    x_test = data[test]
    y_test = target[test]

    # Train on the training folds only
    svc = SVC()
    svc.fit(x_train, y_train)

    # Evaluate on the held-out fold. Scoring on (x_train, y_train) would
    # report training accuracy, which is not a cross-validation score.
    # NOTE: the output recorded after this cell was produced with training-set
    # scoring; re-run the cell to refresh it.
    score = svc.score(x_test, y_test)
    print(score)
    score_list.append(score)

# Mean validation accuracy over all folds
np.array(score_list).mean()
'''
0.9833333333333333
0.9583333333333334
0.9833333333333333
0.9833333333333333
0.9583333333333334
0.9733333333333333
'''

GridSearchCV 网格搜索交叉验证 - 模型调参利器 (必须掌握)

  • GridSearchCV,它存在的意义就是自动调参,只要把参数输进去,就能给出最优化的结果和参数
  • 处理SVM的惩罚因子C,核函数kernel的gamma参数
  • 对于分类任务,使用StratifiedKFold, 对于其他任务,使用KFold
  • GridSearchCV = Grid Search(网格搜索) + CV(Cross Validation,交叉验证)
    • from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import GridSearchCV

使用SVC

# Grid search over SVC hyper-parameters.
#
# estimator : the model to tune, created without the parameters being searched
# param_grid: the candidate values for each parameter
# n_jobs    : number of jobs run in parallel; -1 means use all processors
# cv        : the K of the K-fold cross-validation (default is 5)

svc = SVC()
param_grid = {
    'C': [0.01, 0.1, 1, 3, 5, 10],
    'gamma': [0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0],
}

gv = GridSearchCV(estimator=svc, param_grid=param_grid, n_jobs=-1, cv=5)

# 训练
%time gv.fit(data, target)
'''
Wall time: 2.72 s
GridSearchCV(cv=5, estimator=SVC(), n_jobs=-1,
             param_grid={'C': [0.01, 0.1, 1, 3, 5, 10],
                         'gamma': [0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0]})
'''

# 最佳模型
best_svc = gv.best_estimator_
best_svc

# 最佳参数
gv.best_params_
# {'C': 3, 'gamma': 0.1}

# 最佳得分
gv.best_score_
# 0.9866666666666667

使用逻辑回归

  • from sklearn.linear_model import LogisticRegression

from sklearn.linear_model import LogisticRegression



# C
# solver
# max_iter

lr = LogisticRegression()

param_grid = {
    'C': [0.1, 0.4, 0.6, 0.8, 1.0, 5.0],
    'solver': ['newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'],
    'max_iter': [100, 1000, 2000]
}

gv = GridSearchCV(lr, param_grid, n_jobs=-1)


%time  gv.fit(data, target)
'''
Wall time: 5.46 s

GridSearchCV(estimator=LogisticRegression(), n_jobs=-1,
             param_grid={'C': [0.1, 0.4, 0.6, 0.8, 1.0, 5.0],
                         'max_iter': [100, 1000, 2000],
                         'solver': ['newton-cg', 'lbfgs', 'liblinear', 'sag',
                                    'saga']})
'''

gv.best_estimator_
# LogisticRegression(C=0.4, solver='sag')

gv.best_params_
# {'C': 0.4, 'max_iter': 100, 'solver': 'sag'}

gv.best_score_
# 0.9866666666666667

三、人脸识别

  1. 人脸由不同的颜色组成
  2. 不同的颜色由三原色:红绿蓝组成
  3. 红绿蓝由0~255的数字组成

所以人脸是由数据组成的

导包

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

使用matplotlib.pyplot导入数据图片并显示,观察这张图片的数据类型

进行人脸识别操作

  1. 导入相应模块工具:GridSearchCV、fetch_lfw_people、PCA等
  2. 使用fetch_lfw_people导入数据,如果本地没有会从网络下载,如果本地有数据,加载本地
  3. 查看人脸数据结构
  4. 从人脸数据中提取进行机器学习的关键数据
  5. 对数据进行分割,获取训练数据和测试数据
  6. 数据太复杂了,使用PCA对数据进行降维处理,去除一些不重要的数据
  7. 定义方法获取预测人名和真实人名
  8. 定义方法绘制人脸识别结果图形
  9. 调用方法进行数据展示

获取人脸数据

  • from sklearn.datasets import fetch_lfw_people
  • faces = fetch_lfw_people(min_faces_per_person=70, resize=1)
from sklearn.datasets import fetch_lfw_people

# min_faces_per_person=70: keep only people that have at least 70 images
# resize=1: scaling factor applied to each image; 1 means no scaling
faces = fetch_lfw_people(min_faces_per_person=70, resize=1)

# The first call downloads the dataset, so it can be slow


faces

data = faces['data']
images = faces['images']
target = faces['target']
target_names = faces['target_names']



data.shape
# (1288, 11750)

images.shape
# (1288, 125, 94)

pd.Series(target).unique()
# array([5, 6, 3, 1, 0, 4, 2], dtype=int64)

target_names
'''
array(['Ariel Sharon', 'Colin Powell', 'Donald Rumsfeld', 'George W Bush',
       'Gerhard Schroeder', 'Hugo Chavez', 'Tony Blair'], dtype='<U17')
'''

# plt.imshow(images[30], cmap='gray')

plt.imshow(data[60].reshape(125, 94), cmap='gray')

划分数据集

  • from sklearn.model_selection import train_test_split
from sklearn.model_selection import train_test_split


x_train, x_test, y_train, y_test = train_test_split(data, target, test_size=0.2)

x_train.shape, x_test.shape
# ((1030, 11750), (258, 11750))

PCA降维

  • from sklearn.decomposition import PCA
from sklearn.decomposition import PCA

pca = PCA(n_components=50)


# 训练数据降维
x_train_pca = pca.fit_transform(x_train)

x_train_pca.shape
# (1030, 50)


# 测试数据降维
x_test_pca = pca.transform(x_test)

x_test_pca.shape
# (258, 50)

选择SVC算法, 进行训练

  • from sklearn.svm import SVC
from sklearn.svm import SVC

svc = SVC()
svc.fit(x_train_pca, y_train)


# 预测
svc.predict(x_test_pca)
'''
array([2, 3, 3, 1, 2, 6, 3, 3, 3, 3, 3, 3, 2, 2, 1, 3, 3, 3, 6, 3, 3, 3,
       3, 6, 3, 1, 4, 3, 3, 3, 1, 3, 3, 2, 2, 3, 3, 3, 1, 3, 3, 3, 0, 5,
       1, 3, 2, 4, 1, 5, 3, 3, 6, 3, 6, 3, 3, 6, 1, 3, 1, 3, 3, 3, 1, 3,
       1, 3, 6, 3, 3, 3, 3, 1, 3, 3, 1, 6, 4, 5, 3, 1, 5, 1, 1, 3, 3, 2,
       3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 1, 1, 4, 2, 6, 3, 3, 1, 3, 3, 3, 3,
       3, 6, 3, 3, 1, 3, 4, 1, 0, 3, 6, 1, 3, 3, 1, 3, 3, 3, 3, 3, 6, 3,
       3, 3, 3, 3, 1, 1, 3, 3, 3, 3, 4, 3, 4, 6, 3, 3, 3, 1, 3, 3, 3, 6,
       3, 1, 1, 3, 3, 1, 3, 3, 0, 6, 3, 3, 3, 3, 3, 1, 1, 6, 3, 3, 1, 1,
       3, 3, 1, 1, 3, 3, 3, 2, 3, 3, 0, 3, 1, 1, 3, 1, 1, 3, 3, 3, 3, 3,
       1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 1, 0, 3, 3, 3, 1, 0, 2, 5,
       1, 5, 3, 3, 3, 3, 3, 3, 5, 3, 6, 6, 3, 4, 6, 6, 6, 3, 0, 1, 6, 0,
       6, 3, 3, 3, 2, 3, 1, 5, 3, 3, 3, 3, 2, 1, 3, 0], dtype=int64)
'''

svc.score(x_train_pca, y_train)
# 0.9233009708737864

# 得分
svc.score(x_test_pca, y_test)
# 0.7558139534883721

# 有点过拟合 : 训练集得分比较高,预测效果不好

GridSearchCV进行调参

from sklearn.model_selection import GridSearchCV


svc2 = SVC()
param_grid = {
    'C': [0.01, 0.05, 0.1, 0.5, 1.0, 5.0],
#     'gamma': [0.001, 0.01, 0.1, 1.0, 5.0]
}
gv = GridSearchCV(svc2, param_grid, n_jobs=-1)


# gv.fit(x_train_pca, y_train)
'''
GridSearchCV(estimator=SVC(), n_jobs=-1,
             param_grid={'C': [0.01, 0.05, 0.1, 0.5, 1.0, 5.0]})
'''

# gv.best_score_
# 0.7640776699029126

# gv.best_params_
# {'C': 5.0}

画图

# x_test 
x_test.shape
# (258, 11750)

x_test
'''
array([[125.       , 130.66667  , 138.33333  , ...,  43.333332 ,
         12.666667 ,   5.3333335],
       [ 53.333332 ,  50.666668 ,  53.666668 , ..., 126.       ,
         34.       ,  31.333334 ],
       [ 27.       ,  25.       ,  27.333334 , ..., 249.33333  ,
        245.33333  , 249.33333  ],
       ...,
       [140.33333  , 146.33333  , 138.       , ...,  69.666664 ,
         34.666668 ,  20.666666 ],
       [ 75.666664 ,  74.333336 ,  77.333336 , ..., 163.66667  ,
        167.33333  , 167.66667  ],
       [184.33333  , 181.66667  , 179.       , ...,  37.333332 ,
         38.       ,  38.       ]], dtype=float32)
'''

x_test_pca.shape
# (258, 50)

y_test
'''
array([2, 4, 0, 1, 2, 6, 3, 5, 3, 3, 3, 3, 2, 2, 1, 3, 3, 3, 6, 1, 3, 3,
       3, 6, 4, 1, 4, 5, 3, 1, 1, 3, 3, 2, 2, 3, 2, 3, 1, 3, 3, 3, 0, 5,
       1, 1, 2, 4, 1, 5, 3, 3, 6, 0, 6, 3, 6, 6, 5, 3, 1, 3, 2, 6, 1, 3,
       1, 3, 4, 3, 2, 3, 3, 1, 4, 6, 1, 6, 4, 5, 3, 1, 5, 1, 0, 3, 4, 2,
       3, 0, 3, 3, 3, 0, 3, 3, 3, 3, 1, 1, 6, 2, 5, 3, 0, 1, 4, 3, 3, 6,
       3, 6, 3, 3, 2, 3, 4, 0, 0, 3, 4, 1, 1, 3, 1, 4, 3, 3, 5, 3, 6, 3,
       3, 3, 3, 3, 6, 1, 3, 3, 3, 3, 4, 3, 4, 6, 3, 3, 3, 1, 3, 3, 5, 6,
       3, 1, 1, 3, 3, 1, 3, 6, 0, 6, 3, 3, 6, 3, 3, 2, 1, 6, 3, 1, 1, 1,
       3, 3, 1, 2, 4, 3, 5, 4, 2, 3, 0, 3, 1, 1, 3, 0, 1, 3, 3, 2, 3, 4,
       1, 1, 1, 3, 3, 4, 3, 3, 3, 3, 2, 4, 3, 1, 0, 6, 3, 3, 5, 0, 2, 3,
       1, 5, 3, 3, 3, 1, 3, 3, 5, 5, 6, 6, 3, 6, 5, 6, 6, 1, 0, 5, 6, 0,
       0, 3, 3, 3, 2, 3, 1, 4, 3, 2, 3, 3, 1, 1, 5, 0], dtype=int64)
'''

y_pred = svc.predict(x_test_pca)

y_pred.shape
# (258,)


# Show the first 10 test faces on a 2 x 5 grid. The title is black when the
# prediction matches the true label and red when it does not.
plt.figure(figsize=(5*3, 2*4))

for i in range(10):
    axes = plt.subplot(2, 5, i + 1)

    # Each row of x_test is a flattened 125x94 grayscale image.
    face = x_test[i].reshape(125, 94)
    axes.imshow(face, cmap='gray')
    axes.axis('off')  # hide the axes

    real, pred = y_test[i], y_pred[i]
    title_color = 'k' if real == pred else 'r'
    axes.set_title(f'Real: {real}\nPred: {pred}', fontsize=16, color=title_color)

读取网络数据进行灰度处理

  • gray = [0.299,0.587,0.114]

  • bush_gray = np.dot(bush, gray)

    如果是jpg图片进行归一化操作

data.shape
# (1288, 11750)

bush = plt.imread('../data/bush.jpg')
bush.shape
# (259, 460, 3)

bush
'''
array([[[139,  97,  72],
        [139,  97,  72],
        [139,  97,  72],
        ...,
        [188, 153, 134],
        [193, 158, 139],
        [196, 161, 142]],

       [[139,  97,  72],
        [139,  97,  72],
        [139,  97,  72],
        ...,
        [188, 153, 134],
        [193, 158, 139],
        [196, 161, 142]],

       [[139,  97,  72],
        [139,  97,  72],
        [139,  97,  72],
        ...,
        [188, 153, 134],
        [193, 158, 139],
        [196, 161, 142]],

       ...,

       [[  0,   9,   0],
        [  0,   7,   0],
        [  3,   5,   0],
        ...,
        [ 81,  59,  46],
        [ 81,  59,  46],
        [ 80,  58,  45]],

       [[  0,   9,   0],
        [  0,   7,   0],
        [  3,   5,   0],
        ...,
        [ 81,  59,  46],
        [ 81,  59,  46],
        [ 80,  58,  45]],

       [[  0,   9,   0],
        [  0,   7,   0],
        [  3,   5,   0],
        ...,
        [ 81,  59,  46],
        [ 80,  58,  45],
        [ 80,  58,  45]]], dtype=uint8)
'''

# 灰度处理
gray = [0.299,0.587,0.114]  
bush_gray = np.dot(bush, gray)

bush_gray.shape
# (259, 460)

plt.imshow(bush_gray, cmap='gray')
# plt.imshow(bush, cmap='gray')

把图片形状变成 (125, 94)

  • from scipy import ndimage

截取人脸

  • bush_face = bush_gray[20:160, 190:290]
from scipy import ndimage

bush_face = bush_gray[20:160, 190:290]
bush_face.shape
# (140, 100)

plt.imshow(bush_face, cmap='gray')

# Rescale the cropped face from (140, 100) to (125, 94), the image shape the
# PCA and SVC were trained on; one zoom factor is given per axis.
bush_zoom = ndimage.zoom(bush_face, zoom=(125/140, 94/100))
bush_zoom.shape
# (125, 94)

plt.imshow(bush_zoom, cmap='gray')

PCA降维

bush_zoom.shape
# (125, 94)


bush_test = bush_zoom.reshape(1, -1)

bush_test.shape
# (1, 11750)


bush_test_pca = pca.transform(bush_test)

bush_test_pca.shape
# (1, 50)

预测

y_pred = svc.predict(bush_test_pca)
y_pred
# array([3], dtype=int64)

target_names
'''
array(['Ariel Sharon', 'Colin Powell', 'Donald Rumsfeld', 'George W Bush',
       'Gerhard Schroeder', 'Hugo Chavez', 'Tony Blair'], dtype='<U17')
'''

target_names[3]
# 'George W Bush'
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值