前言
趁着国庆尾巴,复习了Paddle框架进行深度学习实践:手写数字识别,这里分享下模型实现。
1. Paddle手写数字识别过程
这里给大家分享下手写数字识别的主要步骤:
- 定义数据处理过程:定义MnistDataset类,继承自paddle.io.Dataset实现模型输入数据处理,与paddle.io.DataLoader配合使用,实现数据异步加载,提高模型训练速度;
- 定义深度学习模型:这里使用简单的多个卷积层、ReLU激活函数,池化层来提取图像特征,使用全连接层,Softmax实现图像分类;
- 训练配置:使用随机梯度下降SGD来优化模型参数,使用交叉熵作为分类损失函数。
- 训练过程:前向计算,损失计算,模型参数更新三个过程循环进行,直到达到优化目标,即损失值足够小;
- 保存模型:保存上述训练模型参数,以供推理阶段加载使用。
2. Paddle手写数字识别训练与推理过程实现
# 导入飞桨和其他相关库
import paddle
from paddle.nn import Conv2D, MaxPool2D, Linear
import paddle.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import gzip
import os
import json
import random
from PIL import Image
# 创建一个类MnistDataset, 继承paddle.io.Dataset,配合DataLoader实现数据异步加载
class MnistDataset(paddle.io.Dataset):
def __init__(self, mode='train'):
datafile = './work/mnist.json.gz'
data = json.load(gzip.open(datafile))
# 划分数据集为训练集、验证集和测试集
train_set, val_set, test_set = data[:3]
# 图片高度和宽度
self.IMG_ROWS, self.IMG_COLS = 28, 28
if mode == 'train':
# 训练数据集
imgs, labels = train_set[:2]
elif mode == 'valid':
imgs, labels = val_set[:2]
elif mode == 'eval':
imgs, labels = test_set[:2]
else:
raise Exception("mode can only be one of [train, valid, eval]")
# 校验数据
imgs_length = len(imgs)
assert len(imgs) == len(labels), \
"length of train_imgs({}) should be the same with train_labels({})".format(
len(imgs)