第二章主要通过手写数字识别的案例来介绍深度学习
目录
一. 使用飞桨完成手写数字识别模型
手写数字识别任务
数字识别是计算机从纸质文档、照片或其他来源接收、理解并识别可读的数字的能力,目前比较受关注的是手写数字识别。手写数字识别是一个典型的图像分类问题,已经被广泛应用于汇款单号识别、手写邮政编码识别等领域,大大缩短了业务处理时间,提升了工作效率和质量。
在处理手写邮政编码的简单图像分类任务时,可以使用基于MNIST数据集的手写数字识别模型。MNIST是深度学习领域标准、易用的成熟数据集,包含60000条训练样本和10000条测试样本。
- 任务输入:一系列手写数字图片,其中每张图片都是28x28的像素矩阵。
- 任务输出:经过了大小归一化和居中处理,输出对应的0~9的数字标签。
MNIST数据集
MNIST数据集是从NIST的Special Database 3(SD-3)和Special Database 1(SD-1)构建而来。Yann LeCun等人从SD-1和SD-3中各取一半数据作为MNIST训练集和测试集,其中训练集来自250位不同的标注员,且训练集和测试集的标注员完全不同。
MNIST数据集的发布,吸引了大量科学家训练模型。1998年,LeCun分别用单层线性分类器、多层感知器(Multilayer Perceptron, MLP)和多层卷积神经网络LeNet进行实验,使得测试集的误差不断下降(从12%下降到0.7%)。在研究过程中,LeCun提出了卷积神经网络(Convolutional Neural Network,CNN),大幅度地提高了手写字符的识别能力,也因此成为了深度学习领域的奠基人之一。
如今在深度学习领域,卷积神经网络占据了至关重要的地位,从最早LeCun提出的简单LeNet,到如今ImageNet大赛上的优胜模型VGGNet、GoogLeNet、ResNet等,人们在图像分类领域,利用卷积神经网络得到了一系列惊人的结果。
构建手写数字识别的神经网络模型
代码比较
模型均为数据处理、定义网络结构和训练过程三个部分
二. 通过极简方案快速构建手写数字识别模型
前提条件
加载与手写数字识别相关的库
#加载飞桨和相关类库
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
import numpy as np
import os
from PIL import Image
数据处理
通过paddle.dataset.mnist.train()函数设置数据读取器,batch_size设置为8,即一个批次有8张图片和8个标签,代码如下所示。
# 如果~/.cache/paddle/dataset/mnist/目录下没有MNIST数据,API会自动将MINST数据下载到该文件夹下
# 设置数据读取器,读取MNIST数据训练集
trainset = paddle.dataset.mnist.train()
# 包装数据读取器,每次读取的数据数量设置为batch_size=8
train_reader = paddle.batch(trainset, batch_size=8)
paddle.batch函数将MNIST数据集拆分成多个批次,通过如下代码读取第一个批次的数据内容,观察打印结果。
# 以迭代的形式读取数据
for batch_id, data in enumerate(train_reader()):
# 获得图像数据,并转为float32类型的数组
img_data = np.array([x[0] for x in data]).astype('float32')
# 获得图像标签数据,并转为float32类型的数组
label_data = np.array([x[1] for x in data]).astype('float32')
# 打印数据形状
print("图像数据形状和对应数据为:", img_data.shape, img_data[0])
print("图像标签形状和对应数据为:", label_data.shape, label_data[0])
break
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(label_data[0]))
# 显示第一batch的第一个图像
import matplotlib.pyplot as plt
img = np.array(img_data[0]+1)*127.5
img = np.reshape(img, [28, 28]).astype(np.uint8)
plt.figure("Image") # 图像窗口名称
plt.imshow(img)
plt.axis('on') # 关掉坐标轴为 off
plt.title('image') # 图像题目
plt.show()
图像数据形状和对应数据为: (8, 784) [-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -0.9764706 -0.85882354 -0.85882354 -0.85882354
-0.01176471 0.06666672 0.37254906 -0.79607844 0.30196083 1.
0.9372549 -0.00392157 -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -0.7647059 -0.7176471 -0.26274508 0.20784318
0.33333337 0.9843137 0.9843137 0.9843137 0.9843137 0.9843137
0.7647059 0.34901965 0.9843137 0.8980392 0.5294118 -0.4980392
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -0.6156863
0.8666667 0.9843137 0.9843137 0.9843137 0.9843137 0.9843137
0.9843137 0.9843137 0.9843137 0.96862745 -0.27058822 -0.35686272
-0.35686272 -0.56078434 -0.69411767 -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -0.85882354 0.7176471 0.9843137
0.9843137 0.9843137 0.9843137 0.9843137 0.5529412 0.427451
0.9372549 0.8901961 -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -0.372549 0.22352946 -0.1607843 0.9843137
0.9843137 0.60784316 -0.9137255 -1. -0.6627451 0.20784318
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -0.8901961 -0.99215686 0.20784318 0.9843137 -0.29411763
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. 0.09019613 0.9843137 0.4901961 -0.9843137 -1.
-1