MNIST手写字体模型评价指标详解

1 数据处理

1.0 数据获取

# fetch_mldata was removed in scikit-learn 0.22 (mldata.org is offline);
# OpenML hosts the same 70,000 MNIST digits as dataset "mnist_784".
# as_frame=False returns NumPy arrays instead of a DataFrame (scikit-learn >= 0.22).
from sklearn.datasets import fetch_openml

# Download online (cached in scikit-learn's default data home after the first call).
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
print("MNIST datasets: {}".format(mnist))

1.1 数据展示

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml


def load_mnist_original(data_home="./datasets"):
    """Load the 70,000 MNIST digits in the layout of mldata's "MNIST original".

    ``fetch_mldata`` was removed in scikit-learn 0.22 (mldata.org is offline),
    so the data is fetched from OpenML ("mnist_784"; needs scikit-learn >= 0.22).
    That copy keeps the 60,000-train / 10,000-test split but is not sorted by
    label the way "MNIST original" was, so the sort is restored here.

    Args:
        data_home: directory used as the download cache.

    Returns:
        (images, labels): one flattened 28x28 image per row, and float64
        labels 0.0-9.0.
    """
    mnist = fetch_openml("mnist_784", version=1, as_frame=False, data_home=data_home)
    images = np.asarray(mnist["data"])
    # OpenML stores the targets as strings ("0"-"9").
    labels = np.asarray(mnist["target"]).astype(np.float64)
    # Sort the train part and the test part by label independently; the stable
    # sort keeps the original order among images of the same digit.
    order = np.concatenate([
        np.argsort(labels[:60000], kind="stable"),
        60000 + np.argsort(labels[60000:], kind="stable"),
    ])
    return images[order], labels[order]


images, labels = load_mnist_original(data_home="./datasets")
print("Images shape: {}, Labels shape: {}".format(images.shape, labels.shape))
# Pick one sample; its label is printed so the displayed digit can be checked.
test_image = images[36000]
test_label = labels[36000]
print("image label: {}".format(test_label))
plt.figure()
# Each row is a flattened 28x28 image; "nearest" avoids smoothing the pixels.
plt.imshow(test_image.reshape(28, 28), cmap=matplotlib.cm.binary, interpolation="nearest")
plt.show()

在这里插入图片描述

图1.0 数据展示

2 模型评价指标

2.1 原始数据

2.1.0 数据处理

import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, precision_recall_curve, roc_curve

# CJK font used for the Chinese axis labels in the plots (path is system-specific).
font = FontProperties(fname="/usr/share/fonts/truetype/arphic/ukai.ttc")


def load_mnist_original(data_home="./datasets"):
    """Load the 70,000 MNIST digits in the layout of mldata's "MNIST original".

    ``fetch_mldata`` was removed in scikit-learn 0.22 (mldata.org is offline),
    so the data is fetched from OpenML ("mnist_784"; needs scikit-learn >= 0.22).
    That copy keeps the 60,000-train / 10,000-test split but is not sorted by
    label the way "MNIST original" was. The sort is restored here so this
    unshuffled experiment still sees label-ordered data.

    Args:
        data_home: directory used as the download cache.

    Returns:
        (images, labels): one flattened 28x28 image per row, and float64
        labels 0.0-9.0.
    """
    mnist = fetch_openml("mnist_784", version=1, as_frame=False, data_home=data_home)
    images = np.asarray(mnist["data"])
    # OpenML stores the targets as strings ("0"-"9").
    labels = np.asarray(mnist["target"]).astype(np.float64)
    # Sort the train part and the test part by label independently; the stable
    # sort keeps the original order among images of the same digit.
    order = np.concatenate([
        np.argsort(labels[:60000], kind="stable"),
        60000 + np.argsort(labels[60000:], kind="stable"),
    ])
    return images[order], labels[order]


images, labels = load_mnist_original(data_home="./datasets")
# First 60,000 samples form the training set, the last 10,000 the test set.
train_images, train_labels, test_images, test_labels = images[:60000], labels[:60000], images[60000:], labels[60000:]
# NOTE: this section deliberately keeps the original (label-sorted) order.
# Binary task: True where the digit is a 5.
train_labels_5 = (train_labels == 5)
test_labels_5 = (test_labels == 5)
# Linear classifier trained with SGD; fixed seed for reproducible results.
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(train_images, train_labels_5)
test_image = images[36000]
pre_res = sgd_clf.predict([test_image])
print("Predicted result: {}".format(pre_res))
# 3-fold cross-validated accuracy.
cross_value = cross_val_score(sgd_clf, train_images, train_labels_5, cv=3, scoring="accuracy")
print("Modle prediction accuracy: {}".format(cross_value))
# Out-of-fold predictions: each sample is predicted by a model that never saw it.
predict_labels = cross_val_predict(sgd_clf, train_images, train_labels_5, cv=3)
print("predict labels: {}".format(predict_labels))
print("shape of predict labels: {}".format(predict_labels.shape))
cfm = confusion_matrix(train_labels_5, predict_labels)
print("confusion matrix: {}".format(cfm))
precision_value = precision_score(train_labels_5, predict_labels)
print("Precision of model predition: {}".format(precision_value))
recall_value = recall_score(train_labels_5, predict_labels)
print("Recall of model prediction: {}".format(recall_value))
f1_value = f1_score(train_labels_5, predict_labels)
print("F1 score: {}".format(f1_value))
# Raw decision scores (not labels) are needed to sweep the threshold for PR/ROC curves.
predict_values = cross_val_predict(sgd_clf, train_images, train_labels_5, cv=3, method="decision_function")
precisions, recalls, thresholds = precision_recall_curve(train_labels_5, predict_values)
# Output directory for the figures saved below.
if not os.path.exists("./images"):
    os.makedirs("./images")

2.1.1 阈值和精度/召回率

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds,
                                       save_path="./images/pre_recall_threshold.png"):
    """Plot precision and recall against the decision threshold and save the figure.

    Args:
        precisions: precision values from ``precision_recall_curve``.
        recalls: recall values from ``precision_recall_curve``.
        thresholds: thresholds from ``precision_recall_curve``; one element
            shorter than ``precisions``/``recalls``, hence the ``[:-1]`` slices.
        save_path: PNG output path. The default has no "_shuffle" suffix, so
            this unshuffled figure does not clash with the shuffled run's
            "pre_recall_threshold_shuffle.png".
    """
    plt.figure(figsize=(10, 4))
    plt.plot(thresholds, precisions[:-1], "b--", label="精度")
    plt.plot(thresholds, recalls[:-1], "g-", label="召回率")
    plt.xlabel("阈值", fontproperties=font)
    plt.legend(loc="upper left", prop=font)
    plt.ylim([0, 1])
    plt.grid(True)
    # Presumably chosen to fit this model's decision-function score range.
    plt.xlim([-1500000, 600000])
    plt.savefig(save_path, format="png")
    plt.show()


plot_precision_recall_vs_threshold(precisions, recalls, thresholds)

在这里插入图片描述

图2.1 阈值和精度/召回率

2.1.2 精度和召回率

def plot_precision_recall(precisions, recalls, save_path="./images/precision_recall.png"):
    """Plot precision against recall (PR curve) and save the figure.

    Args:
        precisions: precision values from ``precision_recall_curve``.
        recalls: recall values from ``precision_recall_curve``.
        save_path: PNG output path. The default has no "_shuffle" suffix, so
            this unshuffled figure does not clash with the shuffled run's
            "precision_recall_shuffle.png".
    """
    plt.figure(figsize=(6, 6))
    # Drop the final (recall=0, precision=1) point that has no threshold.
    plt.plot(recalls[:-1], precisions[:-1], "g-")
    plt.xlabel("召回率", fontproperties=font)
    plt.ylabel("精度", fontproperties=font)
    plt.ylim([0, 1])
    plt.grid(True)
    plt.xlim([0, 1])
    plt.savefig(save_path, format="png")
    plt.show()


plot_precision_recall(precisions, recalls)

在这里插入图片描述

图2.2 精度和召回率

2.1.3 ROC

def plot_roc_curve(fpr, tpr, label=None, save_path="./images/fpr_tpr.png"):
    """Plot an ROC curve (true positive rate vs. false positive rate) and save it.

    Args:
        fpr: false positive rates from ``roc_curve``.
        tpr: true positive rates from ``roc_curve``.
        label: optional legend text for the curve; a legend is drawn only when given.
        save_path: PNG output path. The default has no "_shuffle" suffix, so
            this unshuffled figure does not clash with the shuffled run's
            "fpr_tpr_shuffle.png".
    """
    plt.figure(figsize=(6, 6))
    plt.plot(fpr, tpr, linewidth=2, label=label)
    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], 'k--')
    plt.axis([0, 1, 0, 1])
    plt.xlabel("假正类率", fontproperties=font)
    plt.ylabel("真正类率", fontproperties=font)
    plt.grid(True)
    if label is not None:
        plt.legend(loc="lower right", prop=font)
    plt.savefig(save_path, format="png")
    plt.show()


# ROC points from the cross-validated decision scores computed for the PR curves.
fpr, tpr, roc_thresholds = roc_curve(train_labels_5, predict_values)
plot_roc_curve(fpr, tpr)

在这里插入图片描述

图2.3 ROC曲线

2.2 数据shuffle处理

2.2.1 数据处理

import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from sklearn.datasets import fetch_openml
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, precision_recall_curve, roc_curve

# CJK font used for the Chinese axis labels in the plots (path is system-specific).
font = FontProperties(fname="/usr/share/fonts/truetype/arphic/ukai.ttc")


def load_mnist_original(data_home="./datasets"):
    """Load the 70,000 MNIST digits in the layout of mldata's "MNIST original".

    ``fetch_mldata`` was removed in scikit-learn 0.22 (mldata.org is offline),
    so the data is fetched from OpenML ("mnist_784"; needs scikit-learn >= 0.22).
    That copy keeps the 60,000-train / 10,000-test split but is not sorted by
    label the way "MNIST original" was, so the sort is restored here.

    Args:
        data_home: directory used as the download cache.

    Returns:
        (images, labels): one flattened 28x28 image per row, and float64
        labels 0.0-9.0.
    """
    mnist = fetch_openml("mnist_784", version=1, as_frame=False, data_home=data_home)
    images = np.asarray(mnist["data"])
    # OpenML stores the targets as strings ("0"-"9").
    labels = np.asarray(mnist["target"]).astype(np.float64)
    # Sort the train part and the test part by label independently; the stable
    # sort keeps the original order among images of the same digit.
    order = np.concatenate([
        np.argsort(labels[:60000], kind="stable"),
        60000 + np.argsort(labels[60000:], kind="stable"),
    ])
    return images[order], labels[order]


images, labels = load_mnist_original(data_home="./datasets")
# First 60,000 samples form the training set, the last 10,000 the test set.
train_images, train_labels, test_images, test_labels = images[:60000], labels[:60000], images[60000:], labels[60000:]
# Shuffle the (label-sorted) training set so every cross-validation fold sees
# all digits. A fixed seed makes the run reproducible, like random_state below.
shuffle_index = np.random.RandomState(42).permutation(60000)
train_images, train_labels = train_images[shuffle_index], train_labels[shuffle_index]
# Binary task: True where the digit is a 5.
train_labels_5 = (train_labels == 5)
test_labels_5 = (test_labels == 5)
# Linear classifier trained with SGD; fixed seed for reproducible results.
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(train_images, train_labels_5)
test_image = images[36000]
pre_res = sgd_clf.predict([test_image])
print("Predicted result: {}".format(pre_res))
# 3-fold cross-validated accuracy.
cross_value = cross_val_score(sgd_clf, train_images, train_labels_5, cv=3, scoring="accuracy")
print("Modle prediction accuracy: {}".format(cross_value))
# Out-of-fold predictions: each sample is predicted by a model that never saw it.
predict_labels = cross_val_predict(sgd_clf, train_images, train_labels_5, cv=3)
print("predict labels: {}".format(predict_labels))
print("shape of predict labels: {}".format(predict_labels.shape))
cfm = confusion_matrix(train_labels_5, predict_labels)
print("confusion matrix: {}".format(cfm))
precision_value = precision_score(train_labels_5, predict_labels)
print("Precision of model predition: {}".format(precision_value))
recall_value = recall_score(train_labels_5, predict_labels)
print("Recall of model prediction: {}".format(recall_value))
f1_value = f1_score(train_labels_5, predict_labels)
print("F1 score: {}".format(f1_value))
# Raw decision scores (not labels) are needed to sweep the threshold for PR/ROC curves.
predict_values = cross_val_predict(sgd_clf, train_images, train_labels_5, cv=3, method="decision_function")
precisions, recalls, thresholds = precision_recall_curve(train_labels_5, predict_values)
# Output directory for the figures saved below.
if not os.path.exists("./images"):
    os.makedirs("./images")

2.2.2 阈值和精度/召回率

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds,
                                       save_path="./images/pre_recall_threshold_shuffle.png"):
    """Plot precision and recall against the decision threshold and save the figure.

    Args:
        precisions: precision values from ``precision_recall_curve``.
        recalls: recall values from ``precision_recall_curve``.
        thresholds: thresholds from ``precision_recall_curve``; one element
            shorter than ``precisions``/``recalls``, hence the ``[:-1]`` slices.
        save_path: PNG output path for the shuffled-data figure.
    """
    plt.figure(figsize=(10, 4))
    plt.plot(thresholds, precisions[:-1], "b--", label="精度")
    plt.plot(thresholds, recalls[:-1], "g-", label="召回率")
    plt.xlabel("阈值", fontproperties=font)
    plt.legend(loc="upper left", prop=font)
    plt.ylim([0, 1])
    plt.grid(True)
    # Presumably chosen to fit this model's decision-function score range.
    plt.xlim([-1500000, 600000])
    plt.savefig(save_path, format="png")
    plt.show()


plot_precision_recall_vs_threshold(precisions, recalls, thresholds)

在这里插入图片描述

图2.4 阈值和精度/召回率(数据shuffle后)

2.2.3 精度和召回率

def plot_precision_recall(precisions, recalls, save_path="./images/precision_recall_shuffle.png"):
    """Plot precision against recall (PR curve) and save the figure.

    Args:
        precisions: precision values from ``precision_recall_curve``.
        recalls: recall values from ``precision_recall_curve``.
        save_path: PNG output path for the shuffled-data figure.
    """
    plt.figure(figsize=(6, 6))
    # Drop the final (recall=0, precision=1) point that has no threshold.
    plt.plot(recalls[:-1], precisions[:-1], "g-")
    plt.xlabel("召回率", fontproperties=font)
    plt.ylabel("精度", fontproperties=font)
    plt.ylim([0, 1])
    plt.grid(True)
    plt.xlim([0, 1])
    plt.savefig(save_path, format="png")
    plt.show()


plot_precision_recall(precisions, recalls)

在这里插入图片描述

图2.5 精度和召回率(数据shuffle后)

2.2.4 ROC

def plot_roc_curve(fpr, tpr, label=None, save_path="./images/fpr_tpr_shuffle.png"):
    """Plot an ROC curve (true positive rate vs. false positive rate) and save it.

    Args:
        fpr: false positive rates from ``roc_curve``.
        tpr: true positive rates from ``roc_curve``.
        label: optional legend text for the curve; a legend is drawn only when given.
        save_path: PNG output path for the shuffled-data figure.
    """
    plt.figure(figsize=(6, 6))
    plt.plot(fpr, tpr, linewidth=2, label=label)
    # Chance-level diagonal for reference.
    plt.plot([0, 1], [0, 1], 'k--')
    plt.axis([0, 1, 0, 1])
    plt.xlabel("假正类率", fontproperties=font)
    plt.ylabel("真正类率", fontproperties=font)
    plt.grid(True)
    if label is not None:
        plt.legend(loc="lower right", prop=font)
    plt.savefig(save_path, format="png")
    plt.show()


# ROC points from the cross-validated decision scores computed for the PR curves.
fpr, tpr, roc_thresholds = roc_curve(train_labels_5, predict_values)
plot_roc_curve(fpr, tpr)

在这里插入图片描述

图2.6 ROC曲线(数据shuffle后)

3 总结

(1) 原始数据("MNIST original")按标签排序,其顺序对训练结果有直接影响:不洗牌时,交叉验证各折中包含的数字类别分布差异很大,因此训练模型时需要对数据进行洗牌(shuffle);
(2) ROC曲线描述的是灵敏度与(1-特异度)之间的关系。真正类率($TPR=\frac{TP}{TP+FN}$)是实际为正类的样本中被正确预测为正类的比率,是召回率的另一个称谓;假正类率($FPR=\frac{FP}{FP+TN}$)等于1-特异度,其中特异度$TNR=\frac{TN}{TN+FP}$;


  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

天然玩家

坚持才能做到极致

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值