断点恢复
在做深度学习训练的时候,由于训练时间比较长,迭代次数比较多,经常会出现无法一次完成train的情况,那么这个时候我们需要用到tensorflow中的断点恢复。不多说直接上例子
#step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])通过文件名得到模型保存时迭代的轮数
#tf.train.get_checkpoint_state函数会通过checkpoint文件自动找到目录中最新模型的文件名
ckpt = tf.train.get_checkpoint_state(CKPT_PATH) if ckpt and ckpt.model_checkpoint_path: #加载模型 saver.restore(sess,ckpt.model_checkpoint_path)
存model的时候,当前step的值被赋予到global_step, 所以 在train的时候要把 global_step的值赋给step,这样才可以从断点处计算。