2.8mnist手写数字识别之模型保存与恢复精讲(百度架构师手把手带你零基础实践深度学习原版笔记系列)
目录
2.8mnist手写数字识别之模型保存与恢复精讲(百度架构师手把手带你零基础实践深度学习原版笔记系列)
模型加载及恢复训练
在之前的章节已经向读者介绍了将训练好的模型保存到磁盘文件的方法。应用程序可以随时加载模型,完成预测任务。但是在日常训练工作中,我们会遇到一些突发情况,导致训练过程主动或被动的中断。如果训练一个模型需要花费几天的时间,中断后从初始状态重新训练是不可接受的。
万幸的是,飞桨支持从上一次保存状态开始继续训练,只要我们随时保存训练过程中的模型状态,就不用从初始状态重新训练。
下面介绍恢复训练的实现方法,依然使用手写数字识别的案例,网络定义的部分保持不变。
import os
import random
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
import numpy as np
from PIL import Image
import gzip
import json
# 定义数据集读取器
def load_data(mode='train'):
# 数据文件
datafile = './work/mnist.json.gz'
print('loading mnist dataset from {} ......'.format(datafile))
data = json.load(gzip.open(datafile))
train_set, val_set, eval_set = data
# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLS
IMG_ROWS = 28
IMG_COLS = 28
if mode == 'train':
imgs = train_set[0]
labels = train_set[1]
elif mode == 'valid':
imgs = val_set[0]
labels = val_set[1]
elif mode == 'eval':
imgs = eval_set[0]
labels = eval_set[1]
imgs_length = len(imgs)
assert len(imgs) == len(labels), \
"length of train_imgs({}) should be the same as train_labels({})".format(
len(imgs), len(labels))
index_list = list(range(imgs_length))
# 读入数据时用到的batchsize
BATCHSIZE = 100
# 定义数据生成器
def data_generator():
if mode == 'train':
random.shuffle(index_list)
imgs_list = []
labels_list = []
for i in index_list:
img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')
label = np.reshape(labels[i], [1]).astype('int64')
imgs_list.append(img)
labels_list.append(label)
if len(imgs_list) == BATCHSIZE: