**
Tenosrflow 训练模型保存
**
1.保存和载入模型
(1)保存模型
saver=tf.train.Saver()
with tf.Session as sess:
sess.run(init)
#...训练
saver.save(sess,"save_path/file_name")#将file_name换成保存的文件名,例如“linermode.cpkt”
(2)载入模型
模型保存后,通过saver的restore()函数调用
saver=tf.train.Saver()
with tf.Session as sess:
saver.restore(sess,"save_path/file_name"
2 tf.train.Saver()介绍
tf.train.Saver(var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=2,
pad_step_number=False,
save_relative_paths=False,
filename=None)
3.保存检查点
在训练过程中,保存模型
saverdir="log/"#模型保存路径
saver=tf.train.Saver()
#保存检查点
with tf.Session() as sess1:
saver.save(sess1,saverdir+"linermodel.cpkt",global_step=epoch)#epoch是迭代次序
#载入检查点,重新开启一个Session
load_epoch=18
with tf.Session() as sess2:
saver.restore(sess2,saverdir+"linermodel.cpkt-"+str(load_epoch))
4.使用MonitoredTrainningSession函数来保存检查点
在大型的数据集训练时,一般都是每隔固定的时间保存一次模型。
import tensorflow as tf
global_step=tf.train.get_or_create_global_step()
with tf.train.MonitoredTrainingSession(checkpoint_dir="log/checkpoints",save_checkpoint_secs=2) as sess:
...
注意:如果不设定save_checkpoint_secs参数,系统默认10分钟保存一次模型。
这种方法保存模型必须要先定义global_step变量,否则会报错