【人工智能与机器学习】第9周--线性分类

实验环境:Jupyter
目的:完成对手写体Mnist数据集中10个字符 (0-9)的分类识别

MNIST

数据介绍:本章使用MNIST数据集,这是一组由美国高中生和人口调查局员工手写的70000个数字的图片。每张图像都用其代表的数字标记。这个数据集被广为使用,因此也被称作是机器学习领域的“Hello World”:但凡有人想到了一个新的分类算法,都会想看看在MNIST上的执行结果。因此只要是学习机器学习的人,早晚都要面对MNIST。

实验步骤及结果

1.使用sklearn的函数来获取MNIST数据集
from sklearn.datasets import fetch_openml
import numpy as np
import os
# to make this notebook's output stable across runs
np.random.seed(42)
# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)
# 为了显示中文
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False
2.获取数据
def sort_by_target(mnist):
    reorder_train=np.array(sorted([(target,i) for i, target in enumerate(mnist.target[:60000])]))[:,1]
    reorder_test=np.array(sorted([(target,i) for i, target in enumerate(mnist.target[60000:])]))[:,1]
    mnist.data[:60000]=mnist.data[reorder_train]
    mnist.target[:60000]=mnist.target[reorder_train]
    mnist.data[60000:]=mnist.data[reorder_test+60000]
    mnist.target[60000:]=mnist.target[reorder_test+60000]

代码不多,但是运行需要花些时间,请小伙伴们耐心等待

如果不知道啥时候能运行好,可以加个计时器看看运行时间:

import time
y1 = time.time()
mnist=fetch_openml('mnist_784',version=1,cache=True)
mnist.target=mnist.target.astype(np.int8)
sort_by_target(mnist)
y2 = time.time()
display(y2-y1)

计时器运行结果:
在这里插入图片描述
我的运行时间(处理时间)为27秒,有点久,电脑性能好的小伙伴处理时间会短些哦

这句代码的作用是可以对数据进行排序:

mnist["data"], mnist["target"]

运行结果:
在这里插入图片描述

3.查看维度

接着是查看维度,有以下几种方法:
1️⃣

mnist.data.shape

2️⃣

X,y=mnist["data"],mnist["target"]
X.shape

3️⃣

y.shape
28*28

运行结果如下:
在这里插入图片描述
表示的是有70000张照片,784维

4.显示图片
def plot_digit(data):
    image = data.reshape(28, 28)
    plt.imshow(image, cmap = mpl.cm.binary,
               interpolation="nearest")
    plt.axis("off")
some_digit = X[36000]
plot_digit(X[36000].reshape(28,28))

结果如下:
在这里插入图片描述
如果想查看更多的图片,比如十行十列的图片,可以使用一下代码查看:

# 更好看的图片展示
def plot_digits(instances,images_per_row=10,**options):
    size=28
    # 每一行有一个
    image_pre_row=min(len(instances),images_per_row)
    images=[instances.reshape(size,size) for instances in instances]
#     有几行
    n_rows=(len(instances)-1) // image_pre_row+1
    row_images=[]
    n_empty=n_rows*image_pre_row-len(instances)
    images.append(np.zeros((size,size*n_empty)))
    for row in range(n_rows):
        # 每一次添加一行
        rimages=images[row*image_pre_row:(row+1)*image_pre_row]
        # 对添加的每一行的额图片左右连接
        row_images.append(np.concatenate(rimages,axis=1)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值