(基于飞浆的)手写数字识别模型——数据处理
1、背景知识
1.1 手写数字识别任务
我们知道,分类问题跟回归问题有明显的区别,回归问题是连续的数值,而分类问题是离散的类别,比如将性别分为[男,女],将图片分为[猫,狗,兔]等。数字识别是计算机从纸质文档、照片或其他来源接收、理解并识别可读的数字的能力,目前比较受关注的是手写数字识别,手写数字识别是一个典型的图像分类问题。
1.2 MNIST数据集
MNIST手写数字数据集包含60000张训练图片和10000张测试图片,这些图片是从0~9的手写数字,分辨率为28*28,大致是下面这个样子:
2. 代码复现
2.1 类库导入
安装飞浆的环境,请参考https://www.paddlepaddle.org.cn/install/quick。在数据处理前,需要通过以下代码加载飞桨与手写数字识别模型相关的类库,实现方法如下:
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
import numpy as np
import os
from PIL import Image
2.2 读入数据并划分数据集
在实际应用中,保存到本地的数据存储格式多种多样,如MNIST数据集以json格式存储在本地,data(数据集)中包含三个元素的列表:train_set、val_set、 test_set。
在本地目录下读取MNIST数据,并拆分成训练集、验证集和测试集,实现方法如下所示。
# 声明数据集文件位置,work目录下的mnist文件
datafile = './work/mnist.json.gz'
print('loading mnist dataset from {} ......'.format(datafile))
# 加载json数据文件
data = json.load(gzip.open(datafile))
print('mnist dataset load done')
# 读取到的数据区分训练集,验证集,测试集
train_set, val_set, eval_set