保存模型
在反向传播中,如果想每隔一定的轮数将模型保存下来,可以用下面的方法。
1)首先,实例化saver对象
saver = tf.train.Saver()
2)在Session会话中,每隔一段轮数,进行模型的保存
with tf.Session() as sess:
for i in range(STEPS):
if i%轮数 == 0:
saver.save(sess, os.path.join(MODEL_SAVER_PATH, MODEL_NAME), global_step=global_step)
加载模型
先判断是否有模型存在,如果有,则恢复到当前的会话中。
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(存储路径)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.mdoel_checkpoint_path)
滑动平均
如果训练的过程中参数使用滑动平均,则每个参数的滑动平均值,也会保存在模型当中。
ema = tf.train.ExpenentialMovingAverage(滑动平均模型)
ema_restore = ema.vaiables_to_restore()
saver = tf.train.Saver(ema_restore)