深度学习课程 DAY 6 - 图像分类问题:手写数字识别案例(八)
Chapter 3 图像分类问题
3.9 模型保存和恢复训练
(1)模型加载
在之前的章节已经向读者介绍了将训练好的模型保存到磁盘文件的方法。应用程序可以随时加载模型,完成预测任务。但是在日常训练工作中,我们会遇到一些突发情况,导致训练过程主动或被动的中断。如果训练一个模型需要花费几天的时间,中断后从初始状态重新训练是不可接受的。从上一次保存状态开始继续训练,只要我们随时保存训练过程中的模型状态,就不用从初始状态重新训练。
但模型加载和恢复比模型预测需要保存的内容要更多。预测时只需要保存模型的参数,而恢复训练时不仅要保存模型之前的参数状态,还有保存优化器内部的状态。例如Adagrad的内部存在状态,根据现有梯度动态调整。
下面介绍恢复训练的实现方法,依然使用手写数字识别的案例,网络定义的部分保持不变。
import os
import random
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
import numpy as np
from PIL import Image
import gzip
import json
# 定义数据集读取器
def load_data(mode='train'):
# 数据文件
datafile = './work/mnist.json.gz'
print('loading mnist dataset from {} ......'.format(datafile))
data = json.load(gzip.open(datafile))
train_set, val_set, eval_set = data
# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLS
IMG_ROWS = 28
IMG_COLS = 28
if mode == 'train':
imgs = train_set[0]
labels = train_set[1]
elif mode == 'valid':
imgs = val_set[0]
labels = val_set[1]
elif mode == 'eval':
imgs = eval_set[0]
labels = eval_set[1]
imgs_length = len(imgs)
assert len(imgs) == len(labels), \
"length of train_imgs({}) should be the same as train_labels({})".format(
len(imgs), len(labels))
index_list = list(range(imgs_length))
# 读入数据时用到的batchsize
BATCHSIZE = 100
# 定义数据生成器
def data_generator():
if mode == 'train':
random.shuffle(index_list)
imgs_list = []
labels_list = []
for i in index_list:
img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')
label = np.reshape(labels[i], [1]).astype('int64')
imgs_list.append(img)
labels_list.append(label)
if len(imgs_list) == BATCHSIZE:
yield np.array(imgs_list), np.array(labels_list)
imgs_list = []
labels_list = []
# 如果剩余数据的数目小于BATCHSIZE,
# 则剩余数据一起构成一个大小为len(imgs_list)的mini-batch
if len(imgs_list) > 0:
yield np.array(imgs_list), np.array(labels_list)
return data_generator
#调用加载数据的函数
train_loader = load_data('train')
# 定义模型结构
class MNIST(fluid.dygraph.