【tensorflow】TF1.x保存与读取.pb模型写法介绍

【tensorflow】TF1.x保存与读取.pb模型写法介绍

  由于TF里面的概念比较接地气,所以用tf1.x保存.pb模型时总是怕有什么操作漏掉了,会造成保存的模型是缺少变量数据或者没有保存图,所以先明确一下:用TF1.x保存模型时只需要保存模型的输入输出的变量(多输入就保存多个),不需要保存中间的变量;用TF1.x加载模型时只需要加载保存的模型,然后读一下输入输出变量(多输入就读多个),不需要初始化(反而会重置掉变量的值)。

举例:模型定义如下

# 定义模型
with tf.name_scope("Model"):
    """MLP"""
    # 13个连续特征数据(13列)
    x = tf.placeholder(tf.float32, [None,13], name='X') 
    # 正则化
    x_norm = tf.layers.batch_normalization(inputs=x)
    # 定义一层Dense
    dense_1 = tf.layers.Dense(64, activation="relu")(x_norm)
    
    """EMBED"""
    # 离散输入
    y = tf.placeholder(tf.int32, [None,2], name='Y')
    # 创建嵌入矩阵变量
    embedding_matrix = tf.Variable(tf.random_uniform([len(vocab_dict) + 1, 8], -1.0, 1.0))
    # 使用tf.nn.embedding_lookup函数获取嵌入向量
    embeddings = tf.nn.embedding_lookup(embedding_matrix, y)
    # 创建 LSTM 层
    lstm_cell = tf.nn.rnn_cell.LSTMCell(64)
    # 初始化 LSTM 单元状态
    initial_state = lstm_cell.zero_state(tf.shape(embeddings)[0], tf.float32)
    # 将输入数据传递给 LSTM 层
    lstm_out, _ = tf.nn.dynamic_rnn(lstm_cell, embeddings, initial_state=initial_state)
    # 定义一层Dense
    dense_2 = tf.layers.Dense(64, activation="relu")(lstm_out[:, -1, :])
    
    """MERGE"""
    combined = tf.concat([dense_1, dense_2], axis = -1)
    pred = tf.layers.Dense(2, activation="relu")(combined)
    pred = tf.layers.Dense(1, activation="linear", name='P')(pred)
    
    z = tf.placeholder(tf.float32, [None, 1], name='Z')

  虽然写这么多,但是上面模型的输入只有xyz,输出只有pred。所以我们保存、加载模型时,只用考虑这几个变量就可以。

模型保存代码

import tensorflow as tf
from tensorflow import saved_model as sm


# 创建 Saver 对象
saver = tf.train.Saver()

# 生成会话,训练STEPS轮
with tf.Session() as sess:
    # 初始化参数
    sess.run(tf.global_variables_initializer())
    
    ...... # 模型训练逻辑
            
    # 准备存储模型
    path = 'pb_model/'
    dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    pb_saver = tf.train.Saver(dense_model_var)
        
    builder = sm.builder.SavedModelBuilder(path)
    
    # 构建需要在新会话中恢复的变量的 TensorInfo protobuf
    # 自定义 根据自己的模型来写
    X = sm.utils.build_tensor_info(x)
    Y = sm.utils.build_tensor_info(y)
    Z = sm.utils.build_tensor_info(z)
    P = sm.utils.build_tensor_info(pred)

    # 构建 SignatureDef protobuf
    # inputs outputs 自定义 根据自己的模型来写
    SignatureDef = sm.signature_def_utils.build_signature_def(
                                inputs={'X': X, 'Y': Y, 'Z': Z},  # 可用sm.signature_constants.PREDICT_INPUTS
                                outputs={'P': P},  # 可用sm.signature_constants.PREDICT_OUTPUTS
                                method_name="tensorflow/serving/predict"
    )

    # 将 graph 和变量等信息写入 MetaGraphDef protobuf
    # 这里的 tags 里面的参数和 signature_def_map 字典里面的键都可以是自定义字符串,也可用tf里预设好的方便统一
    builder.add_meta_graph_and_variables(sess, tags=['serve'],
                                             signature_def_map={
                                                 sm.signature_constants.PREDICT_METHOD_NAME: SignatureDef},
                                             saver=pb_saver,
                                             main_op=tf.local_variables_initializer())

    # 将 MetaGraphDef 写入磁盘
    builder.save()

  最重要的是这一句:dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES),意思是保存当前作用域下的所有可训练的变量。

  我之前写的是dense_model_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, name_scope="Model"),这样读不了所有的可训练变量,只能读到embedding_matrix 一个,虽然也能保存模型,但是没保存模型的其他变量值,就会出错。

模型加载代码

import tensorflow as tf
from tensorflow import saved_model as sm

tf.reset_default_graph()
# 创建一个新的默认图
graph = tf.Graph()

# 需要建立一个会话对象,将模型恢复到其中
with tf.Session(graph=graph) as sess:
    path = 'pb_model/'
    MetaGraphDef = sm.loader.load(sess, tags=['serve'], export_dir=path)

    # 解析得到 SignatureDef protobuf
    SignatureDef_map = MetaGraphDef.signature_def
    SignatureDef = SignatureDef_map[sm.signature_constants.PREDICT_METHOD_NAME]

    # 解析得到 3 个变量对应的 TensorInfo protobuf
    X = SignatureDef.inputs['X']
    Y = SignatureDef.inputs['Y']
    Z = SignatureDef.inputs['Z']
    P = SignatureDef.outputs['P']

    # 解析得到具体 Tensor
    # .get_tensor_from_tensor_info() 函数中可以不传入 graph 参数,TensorFlow 自动使用默认图
    # x = sm.utils.get_tensor_from_tensor_info(X)
    # y = sm.utils.get_tensor_from_tensor_info(Y)
    # z = sm.utils.get_tensor_from_tensor_info(Z)
    
    x = sess.graph.get_tensor_by_name(X.name)
    y = sess.graph.get_tensor_by_name(Y.name)
    z = sess.graph.get_tensor_by_name(Z.name)
    p = sess.graph.get_tensor_by_name(P.name)
    
    # 这里就可以开始进行预测或者继续训练了 TODO
    total_loss = sess.run(loss_function, feed_dict={x: dense_ch_val, y: sparse_ch_val, z: score_val})
    print(total_loss)
  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

征途黯然.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值