- 图保存与加载
- 图保存
with tf.Session() as sess: tf.train.write_graph(sess.graph_def,'./tfmodel','test_pb.pb',as_text=False)
- 图加载
with tf.Session() as sess: with tf.gfile.FastGFile('./tfmodel/test_pb.pb','rb') as f: graph_def=tf.GraphDef() graph_def.ParseFromString(f.read()) sess.graph.as_default() tf.import_graph_def(graph_def,name='tf.graph')
- 建立图
tf.Graph()
- 获得默认图
tf.get_default_graph()
- 重置默认图
tf.reset_default_graph()
- 图保存
- 模型保存和加载
- 模型保存
# 模型保存 saver = tf.train.Saver() saver.save(sess, os.path.join(ckep_dir, 'mnist_h256_model_{:06d}.ckpt'.format(epoch + 1))) # 模型存储
- 模型恢复
# 模型恢复 ckpt = tf.train.get_checkpoint_state(ckep_dir) if ckpt and ckpt.model_checkpoint_path: # 从已保存的模型中读取参数 saver.restore(sess, ckpt.model_checkpoint_path)
- 断点续训
# 检查日志存放目录 ckpt_dir = "CIFAR10_log/" if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir) # 生成saver saver = tf.train.Saver() # 如果有检查点文件,读取最新的检查点文件,恢复变量值 ckpt = tf.train.latest_checkpoint(ckpt_dir) if ckpt != None: saver.restore(sess, ckpt) # 加载所有的参数 else: print("Training from scratch")
- 模型保存
tensorflow中模型训练保存与断点续训
最新推荐文章于 2023-04-14 11:24:13 发布