深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。
方法一:载入模型时,不必指定迭代次数,一般默认最新
# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
# 开启会话
with tf.Session() as sess:
# saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
print("Model restored...")
else:
print('No Model')
方法二:载入时,指定想要载入模型的迭代次数
需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。