TensorFlow-ckpt-模型的保存和调用

更多内容:
https://blog.csdn.net/lwplwf/article/details/62419087

1、保存模型

import tensorflow as tf
import numpy as np

**xxxxx无关紧要代码**

saver = tf.train.Saver(max_to_keep=3)  #关键1  默认max_to_keep=5
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) 
    
    **网络训练部分**  #关键2
    
    save_path = saver.save(sess, r"F:\laboratorie_essays\cbc\WGAN_code\hhhh.ckpt")  #关键3
    print("Save to path: ", save_path)

运行结果:
Save to path: F:\laboratorie_essays\cbc\WGAN_code\hhhh.ckpt

2、调用模型

import tensorflow as tf
import numpy as np

**xxxxx无关紧要代码**

saver = tf.train.Saver()  #关键4
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) 
    saver.restore(sess, r"F:\laboratorie_essays\cbc\WGAN_code\hhhh.ckpt")   #关键5
      
    **删除网络训练部分,这些代码不需要了**  #关键6
        

运行结果:
INFO:tensorflow:Restoring parameters from F:\laboratorie_essays\cbc\WGAN_code\hhhh.ckpt
weights: [[1. 2. 3.]
[3. 4. 5.]]
biases: [[1. 2. 3.]]

3、保存的文件名称含义:
https://blog.csdn.net/sinat_36618660/article/details/98665482

checkpoint: 检查点,记录保存了哪些checkpoint;
.ckpt.data: 以字典的形式保存权重偏置项等训练参数;
.ckpt.index: 参数索引;
.ckpt.meta: 图结构以及图中参数数据。

4、查看checkpoint文件中保存的变量名称:

https://www.cnblogs.com/weizhen/p/8451514.html

import tensorflow as tf
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp
# print all tensors in checkpoint file
chkp.print_tensors_in_checkpoint_file("tmp/model.ckpt", tensor_name=None, all_tensors=True)
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
调用保存好的 TensorFlow 模型,可以使用以下步骤: 1. 定义模型结构和训练过程。 2. 创建一个 `tf.train.Saver` 对象,用于保存和恢复模型。 3. 在训练结束后,调用 `saver.save()` 方法保存模型。 4. 在测试或预测过程中,使用 `tf.train.import_meta_graph()` 方法加载模型的图结构。 5. 创建一个 `tf.Session` 对象,并使用 `saver.restore()` 方法恢复模型的参数。 6. 在 `Session` 中执行模型的前向传播操作,获取预测结果。 以下是一个简单的示例代码,展示如何加载保存好的 TensorFlow 模型: ```python import tensorflow as tf # 定义模型结构和训练过程 x = tf.placeholder(tf.float32, [None, 784], name='x') y = tf.placeholder(tf.float32, [None, 10], name='y') w = tf.Variable(tf.zeros([784, 10]), name='w') b = tf.Variable(tf.zeros([10]), name='b') logits = tf.matmul(x, w) + b loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)) train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss) # 创建 Saver 对象 saver = tf.train.Saver() with tf.Session() as sess: # 恢复模型的图结构 saver = tf.train.import_meta_graph('model.ckpt.meta') # 加载模型的参数 saver.restore(sess, 'model.ckpt') # 获取模型的输入和输出张量 graph = tf.get_default_graph() x = graph.get_tensor_by_name('x:0') y = graph.get_tensor_by_name('y:0') logits = graph.get_tensor_by_name('add:0') # 执行模型的前向传播操作 predictions = tf.argmax(logits, axis=1) test_data = ... test_labels = ... feed_dict = {x: test_data, y: test_labels} results = sess.run(predictions, feed_dict=feed_dict) ``` 在上述代码中,我们首先定义了一个简单的模型结构和训练过程,并使用 `tf.train.Saver` 对象保存模型。在测试或预测过程中,我们使用 `tf.train.import_meta_graph()` 方法加载了模型的图结构,并使用 `saver.restore()` 方法恢复了模型的参数。然后,我们通过 `graph.get_tensor_by_name()` 方法获取了模型的输入和输出张量,并执行了模型的前向传播操作,获取了预测结果。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值