机器学习实战--二分类(MNIST数据集)

import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')

Scikit-Learn 加载数据集通常具有类似于字典的结构,包括:

DESCR:描述数据集

data:包含一个数组,每个实例为一行,每个特征为一列

target:包含一个带有标记的数组

X,y = mnist["data"], mnist["target"]



# 获得第一个数据 
some_digit=X[0]

some_digit
    array([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  51, 159, 253,
           159,  50,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  48, 238,
           252, 252, 252, 237,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  54,
           227, 253, 252, 239, 233, 252,  57,   6,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  10,
            60, 224, 252, 253, 252, 202,  84, 252, 253, 122,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0, 163, 252, 252, 252, 253, 252, 252,  96, 189, 253, 167,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,  51, 238, 253, 253, 190, 114, 253, 228,  47,  79, 255,
           168,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,  48, 238, 252, 252, 179,  12,  75, 121,  21,   0,
             0, 253, 243,  50,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,  38, 165, 253, 233, 208,  84,   0,   0,   0,
             0,   0,   0, 253, 252, 165,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   7, 178, 252, 240,  71,  19,  28,   0,
             0,   0,   0,   0,   0, 253, 252, 195,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,  57, 252, 252,  63,   0,   0,
             0,   0,   0,   0,   0,   0,   0, 253, 252, 195,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0, 198, 253, 190,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0, 255, 253, 196,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  76, 246, 252,
           112,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 253, 252,
           148,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  85,
           252, 230,  25,   0,   0,   0,   0,   0,   0,   0,   0,   7, 135,
           253, 186,  12,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,  85, 252, 223,   0,   0,   0,   0,   0,   0,   0,   0,   7,
           131, 252, 225,  71,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,  85, 252, 145,   0,   0,   0,   0,   0,   0,   0,
            48, 165, 252, 173,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,  86, 253, 225,   0,   0,   0,   0,   0,
             0, 114, 238, 253, 162,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,  85, 252, 249, 146,  48,  29,
            85, 178, 225, 253, 223, 167,  56,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,  85, 252, 252, 252,
           229, 215, 252, 252, 252, 196, 130,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  28, 199,
           252, 252, 253, 252, 252, 233, 145,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,  25, 128, 252, 253, 252, 141,  37,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             0,   0,   0,   0], dtype=uint8)
# 图片是28*28像素的
some_digit_image=some_digit.reshape(28,28)
some_digit_image
    array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0,  51, 159, 253, 159,  50,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,  48, 238, 252, 252, 252, 237,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
             54, 227, 253, 252, 239, 233, 252,  57,   6,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  10,  60,
            224, 252, 253, 252, 202,  84, 252, 253, 122,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 163, 252,
            252, 252, 253, 252, 252,  96, 189, 253, 167,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  51, 238, 253,
            253, 190, 114, 253, 228,  47,  79, 255, 168,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,  48, 238, 252, 252,
            179,  12,  75, 121,  21,   0,   0, 253, 243,  50,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,  38, 165, 253, 233, 208,
             84,   0,   0,   0,   0,   0,   0, 253, 252, 165,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   7, 178, 252, 240,  71,  19,
             28,   0,   0,   0,   0,   0,   0, 253, 252, 195,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,  57, 252, 252,  63,   0,   0,
              0,   0,   0,   0,   0,   0,   0, 253, 252, 195,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0, 198, 253, 190,   0,   0,   0,
              0,   0,   0,   0,   0,   0,   0, 255, 253, 196,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,  76, 246, 252, 112,   0,   0,   0,
              0,   0,   0,   0,   0,   0,   0, 253, 252, 148,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,  85, 252, 230,  25,   0,   0,   0,
              0,   0,   0,   0,   0,   7, 135, 253, 186,  12,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,  85, 252, 223,   0,   0,   0,   0,
              0,   0,   0,   0,   7, 131, 252, 225,  71,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,  85, 252, 145,   0,   0,   0,   0,
              0,   0,   0,  48, 165, 252, 173,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,  86, 253, 225,   0,   0,   0,   0,
              0,   0, 114, 238, 253, 162,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,  85, 252, 249, 146,  48,  29,  85,
            178, 225, 253, 223, 167,  56,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,  85, 252, 252, 252, 229, 215, 252,
            252, 252, 196, 130,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,  28, 199, 252, 252, 253, 252, 252,
            233, 145,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,  25, 128, 252, 253, 252, 141,
             37,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0],
           [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
              0,   0]], dtype=uint8)
# cmap 颜色图谱(colormap), 默认绘制为RGB(A)颜色空间。 binary
plt.imshow(some_digit_image,cmap="binary")
# 不显示坐标轴、坐标轴标签,等价于布尔值False
#plt.axis("off")

在这里插入图片描述

y[0]
 0.0
# from sklearn.model_selection import  train_test_split
# Xtrain, Xtest, Ytrain, Ytest=train_test_split(mnist.data,mnist.target,train_size=0.7)


# 数据集划分
Xtrain, Xtest, Ytrain, Ytest=X[:60000],X[60000:],y[:60000],y[60000:]

y_train_5=(Ytrain==5)
y_train_5
	array([False, False, False, ..., False, False, False])
y_test_5=(Ytest==5)
y_test_5
	array([False, False, False, ..., False, False, False])
# 挑选分类器
from sklearn.linear_model import SGDClassifier # 随机梯度下降分类器 SGD

sgd_clf=SGDClassifier(random_state=42) #SGD训练时完全随机,如果希望得到可复现的结果 需要设置参数random_state
sgd_clf.fit(Xtrain,y_train_5)

sgd_score=sgd_clf.score(Xtest,y_test_5)

sgd_score
	0.968
# 用第一张图测试
sgd_clf.predict([some_digit])
	array([False])
# 分层抽样
from sklearn.model_selection import StratifiedKFold  
from sklearn.base import clone



# n_splits:折叠次数,默认为3,至少为2。
# shuffle:是否在每次分割之前打乱顺序。
# random_state:随机种子,在shuffle==True时使用,默认使用np.random。

skfolds = StratifiedKFold(n_splits=3,random_state=42)

for train_index,test_index in skfolds.split(Xtrain,y_train_5):
    print("train_index",train_index)
    clone_clf = clone(sgd_clf)  # 分类器副本
    Xtrain_folds=Xtrain[train_index]
    print("Xtrain_folds",Xtrain_folds)
    y_tarin_folds=y_train_5[train_index]
    X_test_fold=Xtrain[test_index]
    y_test_fold=y_train_5[test_index]
    
    clone_clf.fit(Xtrain_folds,y_tarin_folds)
    y_pred=clone_clf.predict(X_test_fold)
    n_correct=sum(y_pred==y_test_fold)
    print(n_correct/len(y_pred))
    train_index [18193 18194 18195 ... 59997 59998 59999]
    Xtrain_folds [[0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]
     ...
     [0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]]
    

    D:\Anaconda3\envs\sklearn\lib\site-packages\sklearn\linear_model\stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in <class 'sklearn.linear_model.stochastic_gradient.SGDClassifier'> in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. From 0.21, default max_iter will be 1000, and default tol will be 1e-3.
      "and default tol will be 1e-3." % type(self), FutureWarning)
    

    0.869
    train_index [    0     1     2 ... 59997 59998 59999]
    Xtrain_folds [[0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]
     ...
     [0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]]
    

    D:\Anaconda3\envs\sklearn\lib\site-packages\sklearn\linear_model\stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in <class 'sklearn.linear_model.stochastic_gradient.SGDClassifier'> in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. From 0.21, default max_iter will be 1000, and default tol will be 1e-3.
      "and default tol will be 1e-3." % type(self), FutureWarning)
    

    0.83725
    train_index [    0     1     2 ... 41804 41805 41806]
    Xtrain_folds [[0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]
     ...
     [0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]
     [0 0 0 ... 0 0 0]]
    

    D:\Anaconda3\envs\sklearn\lib\site-packages\sklearn\linear_model\stochastic_gradient.py:128: FutureWarning: max_iter and tol parameters have been added in <class 'sklearn.linear_model.stochastic_gradient.SGDClassifier'> in 0.19. If both are left unset, they default to max_iter=5 and tol=None. If tol is not None, max_iter defaults to max_iter=1000. From 0.21, default max_iter will be 1000, and default tol will be 1e-3.
      "and default tol will be 1e-3." % type(self), FutureWarning)
    

    0.90535

`

# cross_val_score() K-折交叉验证
from sklearn.model_selection import cross_val_score

cross_val_score(sgd_clf,Xtrain,y_train_5,cv=10,scoring='accuracy')
    array([0.9481753 , 0.963     , 0.96433333, 0.9465    , 0.94833333,
           0.96633333, 0.88166667, 0.97016667, 0.91466667, 0.96866144])
# 评估分类器性能的最好方法是混淆矩阵
from sklearn.model_selection import cross_val_predict
# cross_val_predict和cross_val_score同样执行K-折交叉验证,但返回的不是评估分数,而是每个折叠的预测。这意味着对于每个实例都可以得到一个干净的预测
# 干净的意思是模型预测是使用的数据在其训练期间从未见过
y_train_pred=cross_val_predict(sgd_clf,Xtrain,y_train_5,cv=3)

y_train_pred
    array([False,  True, False, ..., False, False, False])
# 获得混淆矩阵
from sklearn.metrics import confusion_matrix
confusion_matrix(y_train_5,y_train_pred)
    array([[47700,  6879],
           [  889,  4532]], dtype=int64)



               预测
            非5     5
    非5   47700   6879
     5    889     4532
# 精度和召回率
from sklearn.metrics import precision_score,recall_score
# 精度 预测为正的样例中有多少是真正的正样例
precision_score(y_train_5,y_train_pred)
    0.3971606344755061
# 召回率 样本中的正例有多少被预测正确
recall_score(y_train_5,y_train_pred)
    0.8360081165836561
# f1_score F1分数可以看作是模型精确率和召回率的一种调和平均 
from sklearn.metrics import f1_score
f1_score(y_train_5,y_train_pred)
    0.5384980988593155
#精度/召回率权衡

# SGDClassifier 对于每个实例,他都会基于决策函数计算出一个分值,如果该值大于阈值 判为正类 反之为父类
# decision_function 返回每个实例的分数
y_scores=sgd_clf.decision_function([some_digit])

y_scores
    array([-316208.02724901])
# 手动设置阈值为0
threshold=0
y_some_digit_pred=(y_scores>threshold)

y_some_digit_pred
    array([False])
# cross_val_predict获取训练集中所有实例的分数
y_scores=cross_val_predict(sgd_clf,Xtrain,y_train_5,cv=10,method="decision_function")

# P-R 曲线
from sklearn.metrics import precision_recall_curve
# precision_recall_curve 计算所有可能的阈值的精度和召回率
predictions,recalls,thresholds=precision_recall_curve(y_train_5,y_scores)

# 精度
predictions.size
    59635
predictions[:-1].size
    59634
# 召回率
recalls.size
    59635
# 阈值
thresholds.size
    59634
def plot_precision_recall_vs_threshold(predictions,recalls,thresholds):
    plt.plot(thresholds,predictions[:-1],"b--",label="Precision")
    plt.plot(thresholds,recalls[:-1],"g-",label="Recall")


plot_precision_recall_vs_threshold(predictions,recalls,thresholds)

在这里插入图片描述

import numpy as np
# np.argmax 会返回最大值的第一个索引 在这种情况下,他表示第一个True的值
threshold_90_precision=thresholds[np.argmax(predictions>=0.9)]

threshold_90_precision
    150261.7279476063
y_train_pred_90=(y_scores>=threshold_90_precision)

precision_score(y_train_5,y_train_pred_90)
    0.9002298850574713
recall_score(y_train_5,y_train_pred_90)
    0.36118797269876407
#ROC曲线

from sklearn.metrics import roc_curve

fpr,tpr,thresholds = roc_curve(y_train_5,y_scores)

def plot_roc_curve(fpr,tpr,label=None):
    plt.plot(fpr,tpr,linewidth=2,label=label)
    plt.plot([0,1],[0,1],'k--')
plot_roc_curve(fpr,tpr)

在这里插入图片描述

from sklearn.metrics import roc_auc_score
# AUC ROC曲线下的面积.
# 计算ROC AUC的值
roc_auc_score(y_train_5,y_scores)
    0.9509513817728653
# 使用随机森林进行分类
from sklearn.ensemble import RandomForestClassifier
forest_clf=RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf,Xtrain,y_train_5,cv=10,method="predict_proba")

y_probas_forest
    array([[0.9, 0.1],
           [0.8, 0.2],
           [1. , 0. ],
           ...,
           [1. , 0. ],
           [1. , 0. ],
           [1. , 0. ]])
y_scores_forest=y_probas_forest[:,1]
y_scores_forest
    array([0.1, 0.2, 0. , ..., 0. , 0. , 0. ])
fpr_forest,tpr_forest,thresholds_forest=roc_curve(y_train_5,y_scores_forest)

plt.plot(fpr,tpr,"b:",label="SGD")
plot_roc_curve(fpr_forest,tpr_forest,"Random Forest")
plt.legend(loc="lower right")

在这里插入图片描述

roc_auc_score(y_train_5,y_scores_forset)
    0.990131230702452
  • 3
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值