TensorFlow模型保存pb或ckpt

Tensorflow的保存分为三种:1. checkpoint模式;2. pb模式;3. saved_model模式。
https://www.zhihu.com/collection/644504409

1 checkpoint模式

1.1 保存

checkpoint模式将网络和变量数据分开保存:

|--checkpoint_dir
|    |--checkpoint
|    |--test-model-550.meta
|    |--test-model-550.data-00000-of-00001
|    |--test-model-550.index

checkpoint_dir就是保存时候指定的路径,路径下会生成4个文件。其中.meta文件(其实就是pb格式文件)用来保存模型结构,.data和.index文件用来保存模型中的各种变量,而checkpoint文件里面记录了最新的checkpoint文件以及其它checkpoint文件列表,在inference时可以通过修改这个文件,指定使用哪个model。

# 只有sess中有变量的值,所以保存模型的操作只能在sess内
checkpoint_dir = "./model_ckpt/"
saver = tf.train.Saver(max_to_keep=1)    # saver 不需要在sess内
with tf.Session() as sess:
    saver.save(sess, checkpoint_dir + "test-model",global_step=i, write_meta_graph=True)

执行之后就可以在checkpoint_dir下面看到前面提到的4个文件了。(这里的max_to_keep是指本次训练在checkpoint_dir这个路径下最多保存多少个模型文件,新模型会覆盖旧模型以节省空间)。

	# Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
    checkpoint_prefix = os.path.join("output/eval_result", "model")
    if not os.path.exists(checkpoint_prefix):
        os.makedirs(checkpoint_prefix)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=config.num_checkpoints)

    # Initialize all variables
    sess.run(tf.global_variables_initializer())

    # 预训练
    saver.restore(sess, '../data/tv.pretrain.ckpt/pretrain_ckpt_841_765_8800/model-8800')  # 加载的ckpt
    print('restore ckpt')
current_step = tf.train.global_step(sess, global_step)
if eval_accuracy > min_accuracy:
    saver.save(sess, checkpoint_prefix, global_step=current_step)  # 保存ckpt
    print("Saved model on {}\n".format(current_step))

1.2 加载

加载ckpt并转换成saved_model模式用于上线部署

(1)restore:恢复ckpt

	sess_config = tf.ConfigProto()
	sess_config.gpu_options.allow_growth = True
	sess = tf.Session(config=sess_config)

    saver = tf.train.Saver(tf.global_variables(), max_to_keep=config.num_checkpoints)

    # Initialize all variables
    sess.run(tf.global_variables_initializer())
    print('ckpt_path:', ckpt_path)
    saver.restore(sess, ckpt_path)  # 恢复数据,ckpt地址

(2)写好feed_dict,然后sess.run
(3)builder.save(),重新保存为saved_model模式

		# 保存pb文件
        if save_pb:
            builder = tf.saved_model.builder.SavedModelBuilder(path+'/model' + str(current_step))
            builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.SERVING])
            builder.save()
            print("Saved model on {}\n".format(current_step))

2 pb模式:convert_variables_to_constants

https://www.jianshu.com/p/091415b114e2
https://www.tensorflow.org/api_docs/python/tf/compat/v1/graph_util/convert_variables_to_constants

2.1 保存

pb模式保存的模型,只有在目标路径pb_dir = "./model_pb/“下孤孤单单的一个文件"test-model.pb”,这也是它相比于其他几种方式的优势,简单明了。

# 只有sess中有变量的值,所以保存模型的操作只能在sess内
pb_dir = "./model_pb/"
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph_def = tf.get_default_graph().as_graph_def()
    # 这里是指定要冻结并保存到pb模型中的变量
    var_list = ["input", "label", "beta", "bias", "output"]   # 如果有name_scope,要写全名,如:"name_scope/beta" 
    constant_graph = tf.graph_util.convert_variables_to_constants(sess, graph_def, var_list)
    with tf.gfile.FastGFile(pb_dir + "test-model.pb", mode='wb') as f:
        f.write(constant_graph.SerializeToString())

其实pb模式本质上就是把变量先冻结成常数,然后保存到图结构中。这样就可以直接加载图结构和“参数”了。

# pb模式:convert_variables_to_constants
        if pb_mode:
            pb_dir = './only_pb_1223/'  # 提前新建
            graph_def = tf.get_default_graph().as_graph_def()
            var_list = ["audio_input_x", "query_input_x", ..., "output/probability", ...]
            constant_graph = tf.graph_util.convert_variables_to_constants(sess, graph_def, var_list)
            with tf.gfile.FastGFile(pb_dir + "test-model.pb", mode='wb') as f:
                f.write(constant_graph.SerializeToString())

convert_variables_to_constants:把变量转换成常量
参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层采用了name_scope,所以我们在probability之前需要加上output/)

2.2 加载

pb模式的加载旧没那么复杂,因为他的网络结构和数据是存在一起的。

import numpy as np
import tensorflow as tf

# 直接从pb获取tensor
pb_dir = "./model_pb/"
with tf.gfile.FastGFile(pb_dir + "test-model.pb", "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())    # 从pb文件中导入信息
    # 从网络中通过tensor的name获取为变量
    X, pred = tf.import_graph_def(graph_def, return_elements=["input:0", "output:0"])

现在已经有了X和pred,下面来跑一个pred

# 假设这是一个batch
feed_X = np.ones((8,size)).astype(np.float32)
feed_y = np.ones((8,1)).astype(np.float32)
# 跑一下 pred
with tf.Session() as sess:
    # sess.run(tf.global_variables_initializer())
    print(sess.run(pred, feed_dict={X:feed_X}))

从pb中获取进来的“变量”就可以直接用。为什么我要给变量两个字打上引号呢?

  • 因为在pb模型里保存的其实是常量了,取消注释sess.run(tf.global_variables_initializer())后,多次运行的结果还是一样的。此时的“beta:0”和"bias:0"已经不再是variable,而是constant。
  • 这带来一个好处:读取模型中的tensor可以在Session外进行。相比之下checkpoint只能在Session内读取模型,对Fine-tune来说就比较麻烦。
    with tf.gfile.FastGFile(saved_model_dir + "test-model-noCnnFeatureNoDropout.pb", "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())  # 从pb文件中导入信息
        # 从网络中通过tensor的name获取为变量
        audio_input_x, query_input_x, probability
            = tf.import_graph_def(graph_def, return_elements=["audio_input_x:0","query_input_x:0", "output/probability:0"])

打印维度

print("audio_input_x:0", audio_input_x)
print("query_input_x:0", query_input_x)

查看运算节点:TensorFlow查看输入节点和输出节点名称

tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
	print(tensor_name, '\n')
  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
要将训练好的 TensorFlow 模型保存为 .pb 文件,您可以按照以下步骤进行操作: 1. 定义模型结构:在保存模型之前,您需要定义模型的结构,包括输入和输出节点的名称、形状和数据类型。您可以使用 TensorFlow 的高级 API(如 Keras)或自定义模型来定义模型结构。 2. 加载模型权重:将训练好的模型权重加载到定义的模型结构中。这可以通过加载已保存模型权重文件(如 .h5、.ckpt 等)或通过重新训练模型来实现。 3. 创建 SavedModel:使用 TensorFlow 的 `tf.saved_model.save` 函数将模型保存为 SavedModel 格式。SavedModel 是 TensorFlow 的一种标准模型保存格式,可以包含模型的计算图和变量值。 ```python import tensorflow as tf # 定义和加载模型权重 model = ... # 定义模型结构 model.load_weights('model_weights.h5') # 加载模型权重 # 保存为 SavedModel 格式 tf.saved_model.save(model, 'saved_model') ``` 这将会在指定路径下创建一个名为 `saved_model` 的文件夹,其中包含了模型的计算图和变量值。 4. 导出为 .pb 文件:从 SavedModel 中导出所需的 .pb 文件。可以使用 TensorFlow 的 `tf.compat.v1.graph_util.convert_variables_to_constants` 函数将 SavedModel 的计算图和变量值转换为常量,并保存为 .pb 文件。 ```python from tensorflow.python.framework import graph_util # 加载 SavedModel saved_model_dir = 'saved_model' saved_model = tf.saved_model.load(saved_model_dir) # 将 SavedModel 转换为 .pb 文件 output_pb_file = 'model.pb' graph_def = graph_util.convert_variables_to_constants( saved_model.sess, saved_model.sess.graph_def, ['output_node_name']) with tf.io.gfile.GFile(output_pb_file, 'wb') as f: f.write(graph_def.SerializeToString()) ``` 将上述代码中的 `'output_node_name'` 替换为模型输出节点的名称。 现在,您应该已经成功将训练好的 TensorFlow 模型保存为 .pb 文件。请注意,这只是一个基本示例,具体的实现细节可能因您的模型结构和需求而有所不同。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值