tensorflow模型ckpt持久化

模型保存

sess = tf.Session()#开启会话
sess.run(tf.global_variables_initializer())#所有变量初始化
saver.save(sess, './model_test/model.ckpt',global_step=300)#迭代次数每300保存一次

保存函数

tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=None)

max_to_keep:默认保存最近5个模型
keep_checkpoint_every_n_hours:隔一段固定的时间保存模型
global_step:每迭代多少次保存一次

如果只想保存指定变量,在创建tf.train.Saver实例时,可以通过将需要保存的变量构造list或者dictionary,传入到Saver中,如saver = tf.train.Saver([w1,w2])表示只保存w1,w2。

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')

保存的模型如图:
在这里插入图片描述
.meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量、op、集合等;
ckpt文件是二进制文件,保存了所有的weights、biases、gradients等变量。在tensorflow 0.11之前,保存在.ckpt文件中。0.11后,通过两个文件保存,即.data文件和.index文件;checkpoint文件,该文件是个文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model

加载模型(即加载图和参数)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('./model_test/model.ckpt-300.meta')# 加载图结构
    saver.restore(sess,tf.train.latest_checkpoint('./model_test/'))# 加载参数
    print(sess.run('w1:0'))显示模型里w1的值
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值