【机器学习8】模型评估+识别Mnist数据集的字符

整体要求:

1、阅读“机器学习”(周志华著)第二章“模型评估与选择”,理解“查准率”、“查全率”、“F1-Score”、“ROC”、“混淆矩阵”的定义。

2、学习“机器学习实战”第三章-分类器,Jupyter编程完成对手写体Mnist数据集中10个字符 (0-9)的分类识别。
(其中掺杂本人理解,不喜勿喷,欢迎纠错~)

1.查准率与查全率

在信息检索、Web搜索中,我们经常会关心“检索出的信息中有多少比例是用户感兴趣的”,所以引出“ 查准率”(precision)与 “查全率”(recall)作为此类需求的性能度量。
一件事的真假,和机器预测这件事的正反,两者组成真正、真反、假正、假反四种情况,四种情况加起来=样例总例,
分类结果的混淆矩阵如下:
在这里插入图片描述
查准率P与查全率R分别定义为:

在这里插入图片描述
P高,R就低,反之亦然~
给P-R画个图就是:
在这里插入图片描述

  • 图中有三个机器的PR图,因为A包住了B,所以A机器性能好一些~
  • 但是肉眼观察的包住并不准确,所以需要平衡点来帮助我们判断~~
  • 平衡点就是P=R的时候的点,平衡点英文叫BEP

2.F1

用平衡点来衡量性能过于简单~
于是就有了F1:
在这里插入图片描述
但有时候需要查全率更高一些,有时候需要查准率重要一些,于是推出了β!
在这里插入图片描述

β> 1时查全率有更大影响; β < 1时查准率有更大影响.β=1,就是F1。

3.ROC

中文名:“受试者工作特征”曲线
就是和P-R类似~
就是TPR-FPR的曲线~
公式中的参数 就是P-R用到的参数,去上面看吧~
在这里插入图片描述
根据公式画出一个例子:
图中AUC是啥??就是包住的面积啦~
在进行多个机器比较时,面积越大越好呗~
在这里插入图片描述

4.混淆矩阵

去上面看1的表,看到表里面的中文,应该好理解~

5.Jupyter编程完成对手写体Mnist数据集中10个字符 (0-9)的分类识别

第一步:介绍Mnist
MNIST数据集,这是一组由美国高中生和人口调查局员工手写的70000个数字的图片。每张图像都用其代表的数字标记。

第二步:使用sklearn的函数来获取MNIST数据集

from sklearn.datasets import fetch_openml
import numpy as np
import time
import os
# to make this notebook's output stable across runs
np.random.seed(42)
# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)
# 为了显示中文
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False

def sort_by_target(mnist):
    reorder_train=np.array(sorted([(target,i) for i, target in enumerate(mnist.target[:60000])]))[:,1]
    reorder_test=np.array(sorted([(target,i) for i, target in enumerate(mnist.target[60000:])]))[:,1]
    mnist.data[:60000]=mnist.data[reorder_train]
    mnist.target[:60000]=mnist.target[reorder_train]
    mnist.data[60000:]=mnist.data[reorder_test+60000]
    mnist.target[60000:]=mnist.target[reorder_test+60000]
    
mnist=fetch_openml('mnist_784',version=1,cache=True)
mnist.target=mnist.target.astype(np.int8)
sort_by_target(mnist)
X,y=mnist["data"],mnist["target"]

784指的像素,28x28=784
于是我们就得到了minist数据集

第三步:随便打印一个数据

# 展示图片
def plot_digit(data):
    image = data.reshape(28, 28)
    plt.imshow(image, cmap = mpl.cm.binary,
               interpolation="nearest")
    plt.axis("off")
some_digit = X[56000]
plot_digit(X[56000].reshape(28,28))

结果如下:

在这里插入图片描述

第四步:打印更好看的数据出来(其实没啥用~)

# 更好看的图片展示
def plot_digits(instances,images_per_row=10,**options):
    size=28
    # 每一行有一个
    image_pre_row=min(len(instances),images_per_row)
    images=[instances.reshape(size,size) for instances in instances]
#     有几行
    n_rows=(len(instances)-1) // image_pre_row+1
    row_images=[]
    n_empty=n_rows*image_pre_row-len(instances)
    images.append(np.zeros((size,size*n_empty)))
    for row in range(n_rows):
        # 每一次添加一行
        rimages=images[row*image_pre_row:(row+1)*image_pre_row]
        # 对添加的每一行的额图片左右连接
        row_images.append(np.concatenate(rimages,axis=1))
    # 对添加的每一列图片 上下连接
    image=np.concatenate(row_images,axis=0)
    plt.imshow(image,cmap=mpl.cm.binary,**options)
    plt.axis("off")
plt.figure(figsize=(9,9))
example_images=np.r_[X[:12000:600],X[13000:30600:600],X[30600:60000:590]]
plot_digits(example_images,images_per_row=10)
plt.show()

效果如下:

在这里插入图片描述

第五步:创建一个测试集(不是很清楚这个是啥,还要洗牌)

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
import numpy as np

shuffer_index=np.random.permutation(60000)
X_train,y_train=X_train[shuffer_index],y_train[shuffer_index]

第六步:训练一个二分类器
后面再更新~~

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值