以下内容基于TensorFlow1.2
- checkpoints保存和恢复模型
-
保存模型
- 定义要输出的tensor
output_name = ... # 如果是预测,output_tensor就是最后得出预测结果的tensor(预测的y) output_tensor = ... tf.add_to_collection(output_name, output_tensor)
- 保存checkpoint
checkpoint_dir = ... global_step = ... with tf.Session() as sess: # max_to_keep 最多保存几个checkpoint的文件 saver = tf.train.Saver.save(tf.global_variables(), max_to_keep=5) saver.save(sess, checkpoint_dir, global_step=global_step)
-
恢复模型
# 必须同保存模型时定义的output_name的值一致 output_name = ... with tf.Session(
-