TensorFlow入门 |(三)MNIST手写数字识别

【更多精彩内容请各位读者移步公众号*“econe认视界**”探索】*

相关阅读:
TensorFlow入门 |(零)环境安装
TensorFlow入门 |(一)基本知识详解
TensorFlow入门 |(二)房价预测模型

(一)MNIST手写体数据集介绍

MNIST手写数字识别算是计算机视觉入门的一个项目。也是传说中机器学习入门的“Hello World”。小编写这篇文章的时候也算是一个新手,分享下自己的学习成果,和大家一起学习进步。

MNIST数据集包含各种(0~9)手写数字:在这里插入图片描述
数据集官网:http://yann.lecun.com/exdb/mnist/ ,数据集包含60000个训练样本和10000个测试样本。网站提供的文件如下图,有相应的训练集和测试集图片及对应标签。详细描述信息可查看官网阅读。
在这里插入图片描述
数据集里的每张数字图片都是28×28(=784)的像素组成,且都是256阶(0~255)的灰度图
在这里插入图片描述
为了加速训练,我们需要把每张数字图片做规范化处理,也就是把每个像素压缩到0~1的范围(此时也叫像素的强度值)。
在这里插入图片描述
用Keras加载数据集:
tf.keras.datasets.mnist.load_data(path='mnist.npz')
Arguments:
path:本地缓存MNIST数据集(MNIST.npz)的相对路径(~/.keras/datasets)
Returns:
Tuple of Numpy arrays: ‘(x_train, y_train), (x_text, y_text)’.
详见 mnist.load_data API文档
若在加载数据集过程中报错,应该是数据集路径有问题。具体解决办法可参考我的博客 MNIST数据集加载的问题&方法

具体代码如下:

from keras.darasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

输出结果如下:

(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)

如前所述,60000个训练样本图片,大小为28×28像素,相应有60000个标签。测试集同理。

下面可视化数据集前几张图片看下。

import matplotlib.pyplot as plt
fig = plt.figure()           # 创建一张图
# 显示前十五张灰度图片
for i in range(15):
	plt.subplot(3, 5, i+1)   # 用3行5列形式展示
	plt.tight_layout()       # 自动适配子图位置,看起来不拥挤
	plt.imshow(x_train[i], cmap = 'Greys')     # 用灰色显示图像灰度值
	plt.xticks([])           # 删除x轴标记,否则会自动表上坐标
	plt.yticks([])

在这里插入图片描述

(二)MNIST softmax网络

我们要解决识别10个数字这样多分类的问题,最直接简单的就是应用softmax网络。关于softmax网络的具体理论在这里不多赘述,默认各位读者都知道,不知道的话网上搜一下看看,很容易理解,这里小编主要来说说如何应用。

将手写数字图片作为 [784] 的一维向量输入;中间定义2层512个神经元的隐藏层,最后定义1层10个神经元的全连接层,用于输出10个不同类别的“概率”
在这里插入图片描述

数据处理
归一化

对手写数字图像维度转换[28, 28] -->[784],再归一化处理

# 数据规范化
X_train = x_train.reshape(60000, 784)
X_test = x_test.reshape(10000, 784)
# 转换层浮点数,否则整形归一化没有意义
X_train = x_train.astype('float32')
X_test = x_test.astype('float32')
X_train = x_train / 255
X_test = x_test / 255

统计训练数据的标签数量并做统计

label, count = np.unique(y_train, return_counts=True)
print(label, count)
[0 1 2 3 4 5 6 7 8 9] [5923 6742 5958 6131 5842 5421 5918 6265 5851 5949]

fig = plt.figure()
plt.bar(label, count, width = 0.7, align = 'center')   # 用柱状图可视化,横坐标label,纵坐标count,关于中心对称
plt.title('Label Distribution')
plt.xlabel('Label')
plt.ylabel('Count')
plt.xticks(label)   # 把label分布在x轴上
plt.ylim(0, 7500)   # y轴取值区间

# 在柱状图上方显示每个标签的数量,以便我们观察
for a,b in zip(label, count):
    plt.text(a,b,'%d'%b, ha = 'center', va = 'bottom', fontsize = 10 )   
    
  • 4
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值