Tenosrflow 训练模型保存

**

Tenosrflow 训练模型保存

**

1.保存和载入模型
(1)保存模型

saver=tf.train.Saver()
with tf.Session as sess:
    sess.run(init)
    #...训练
    saver.save(sess,"save_path/file_name")#将file_name换成保存的文件名,例如“linermode.cpkt”

(2)载入模型
模型保存后,通过saver的restore()函数调用

saver=tf.train.Saver()
with tf.Session as sess:
	saver.restore(sess,"save_path/file_name"

2 tf.train.Saver()介绍
tf.train.Saver(var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=2,
pad_step_number=False,
save_relative_paths=False,
filename=None)
3.保存检查点
在训练过程中,保存模型

saverdir="log/"#模型保存路径
saver=tf.train.Saver()
#保存检查点
with tf.Session() as sess1:
	saver.save(sess1,saverdir+"linermodel.cpkt",global_step=epoch)#epoch是迭代次序
#载入检查点,重新开启一个Session
load_epoch=18
with tf.Session() as sess2:
	saver.restore(sess2,saverdir+"linermodel.cpkt-"+str(load_epoch))

4.使用MonitoredTrainningSession函数来保存检查点
在大型的数据集训练时,一般都是每隔固定的时间保存一次模型。

import tensorflow as tf
global_step=tf.train.get_or_create_global_step()
with tf.train.MonitoredTrainingSession(checkpoint_dir="log/checkpoints",save_checkpoint_secs=2) as sess:
	...

注意:如果不设定save_checkpoint_secs参数,系统默认10分钟保存一次模型。
这种方法保存模型必须要先定义global_step变量,否则会报错
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值