将tf1训练的模型导入tf2进行推理

最近做比赛遇到一个问题,tf1训练的模型提交后因为环境问题导致线上无法运行,故尝试线上用tf2进行推理。

步骤需要1.将tf1的.pd模型结构和ckpt存档导出。2.将模型结构和ckpt存档转化为.pd静态图(frozen graph)。3.使用tf2读取.pd静态图(frozen graph)进行推理

首先定义基于tf1的模型

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


with tf.Session() as sess:
    a = tf.placeholder(tf.float32, [1])
    b = tf.placeholder(tf.float32, [1])
    c = tf.get_variable("w", [1])
    d = a*c
    out = tf.add(d, b)
    # 初始化变量
    sess.run(tf.variables_initializer(tf.global_variables()))
    # 保存图结构
    tf.train.write_graph(sess.graph_def, './', 'graph_define.pb', as_text=True)
    # 保存参数存档
    saver = tf.train.Saver()
    saver.save(sess, 'checkpoint.ckpt')

之后转化模型为静态图(frozen graph)

import tensorflow as tf
from tensorflow.python.tools import freeze_graph

# 转化模型
with tf.compat.v1.Session() as sess:
    freeze_graph.freeze_graph(
        input_graph='./graph_define.pb',
        input_saver='',
        input_binary=False,
        input_checkpoint='./checkpoint.ckpt',
        output_node_names='Add',
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        output_graph='./frozen_model.pb',
        clear_devices=False,
        initializer_nodes=''
    )

其中input_graph是图结构文件,input_checkpoint是参数文件,output_node_names是输出节点名

需要注意的是模型训练时必须是基于原生tf1的,如果训练时使用tf.compat.v1而不引入tf.disable_v2_behavior(),会在转换时报错。

最后使用tf2加载静态图进行推理

import tensorflow as tf


def wrap_frozen_graph(graph_def, inputs, outputs):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph

    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))


def load_pb(filename):
    # Load frozen graph using TensorFlow 1.x functions
    with tf.io.gfile.GFile(filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        loaded = graph_def.ParseFromString(f.read())

    # Wrap frozen graph to ConcreteFunctions
    frozen_func = wrap_frozen_graph(
        graph_def=graph_def,
        inputs=['Placeholder:0', 'Placeholder_1:0'],  # input tensor name of your model
        outputs="Add:0"  # output tensor name of your model
    )
    return frozen_func


model = load_pb('frozen_model.pb')

print(model(tf.constant(3, tf.float32), tf.constant(4, tf.float32)))

实测比赛代码也可以在线上运行了

参考资料:

TF 保存模型为 .pb格式 - 静悟生慧 - 博客园

tensorflow的三种保存格式总结-1(.ckpt) - 知乎

[深度学习] TensorFlow中模型的freeze_graph - 知乎

如何在TF2中使用TF1.x的.pb模型_MD笔记-CSDN博客

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值