TensorFlow 的三种模型保存格式及其相互转换

首先介绍三种模型导出方式:

  • tf.train.Saver()

用于保存和恢复 Variable。它可以非常方便地保存当前模型的变量，或者导入之前训练好的变量。一个最简单的用法:

# Minimal tf.train.Saver round trip: write the variables of the current graph
# to a checkpoint file, then load them back into the same session.
ckpt_path = "/tmp/test.ckpt"
saver = tf.train.Saver()
saver.save(sess, ckpt_path)     # variables -> disk
saver.restore(sess, ckpt_path)  # disk -> variables
1. ckpt格式
#saver.save(sess, '../tf-model/', global_step=1, write_meta_graph=True)
2. pb 格式（冻结图，变量已固化为常量）
      # Output head of the punctuation task. Creating these ops under the
      # "whichPun" scope, with name='output' on the final op, yields the graph
      # node "whichPun/output" that the freeze/export code below looks up.
      with tf.variable_scope("whichPun"):
            # Dense layer with units=5 (presumably 5 punctuation classes -- confirm).
            # trainable=False keeps its variables out of the trainable collection,
            # so optimizers will not update them by default.
            task_2 = tf.layers.dense(bert_output, units=5, activation=None, trainable=False)
            # NOTE(review): the label says "bert_output" but the value printed is task_2.
            print("bert_output== ", task_2)
            task_2 = tf.cast(task_2, tf.float32)
            # Reshape to [-1, self.input_shape[1], 5]; assumes input_shape[1] is the
            # sequence length -- TODO confirm. The op name 'output' is the export handle.
            self.logit = tf.reshape(task_2, [-1, self.input_shape[1], 5], name='output')

在输出所在的 scope 里找到输出节点的名字（这里是 whichPun/output）:

 #2. 保存为pb 在sess中两行
# Freeze: replace every variable reachable from the listed output nodes with a
# constant holding its current value in `sess`. The node list must be a Python
# list of node names (no ":0" tensor suffix).
# Use the session's own graph: that is where the variables live, and it need not
# be the current default graph.
frozen_graph_def = graph_util.convert_variables_to_constants(
    sess,
    sess.graph.as_graph_def(),
    ['whichPun/output'])
# tf.gfile.FastGFile is deprecated; GFile is its supported replacement.
with tf.gfile.GFile('graph.pb', mode='wb') as f:
    f.write(frozen_graph_def.SerializeToString())
3. TF Serving（tfs）格式：SavedModel 模块（导出后可以用下面的 saved_model_cli 命令查看模型的签名，./6 为导出目录）

saved_model_cli show --dir ./6 --all

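旧版 Exporter 的基本使用方式是:

1) 传入一个 Saver 实例;

2) 调用 init, 定义模型的 graph 以及 input/output;

3) 使用 Exporter 导出模型。

下面的代码使用的是它的替代接口 SavedModelBuilder，流程类似：1) 创建 SavedModelBuilder(export_path)；2) 用 build_tensor_info / build_signature_def 定义 input/output 签名；3) 调用 add_meta_graph_and_variables 和 save 导出模型。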

        # 3. Export the frozen graph as a TF Serving SavedModel. `model`,
        # `frozen_graph_def` and `tf` come from the surrounding code.
       import shutil

       with tf.Graph().as_default() as graph:
           # Import the frozen GraphDef (variables already folded into constants)
           # into a fresh graph so the SavedModel contains only inference ops.
           tf.import_graph_def(frozen_graph_def, name="")
           # Use a distinct name: `with ... as sess` would rebind the caller's
           # `sess` to this session, which is closed when the block exits.
           with tf.Session() as export_sess:
               export_path = "savedmodel"
               # SavedModelBuilder refuses to write into an existing directory.
               shutil.rmtree(export_path, ignore_errors=True)
               builder = tf.saved_model.builder.SavedModelBuilder(export_path)
               # model.* are tensors of the training graph. Resolve them by name in
               # the imported graph, so a tensor missing from the frozen graph fails
               # here instead of silently producing a broken signature.
               inids = tf.saved_model.utils.build_tensor_info(
                   graph.get_tensor_by_name(model.input_ids.name))
               inmask = tf.saved_model.utils.build_tensor_info(
                   graph.get_tensor_by_name(model.input_mask.name))
               poutput = tf.saved_model.utils.build_tensor_info(
                   graph.get_tensor_by_name(model.logit.name))

               # The signature maps serving-side keys to tensors; the keys are free
               # to choose and independent of the tensor names in the graph.
               prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
                   inputs={'input': inids, 'mask': inmask},
                   outputs={'punc_output': poutput},
                   method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
               # Record the graph (and variables, if any) under the SERVING tag.
               builder.add_meta_graph_and_variables(
                   export_sess, [tf.saved_model.tag_constants.SERVING],
                   signature_def_map={
                       'ac_forward': prediction_signature,
                   })

               builder.save()
#

模型格式之间的相互转换

ckpt2pb.py
import tensorflow as tf
from sys import argv

# Usage: python ckpt2pb.py CKPT_DIR OUTPUT_PB [OUTPUT_NODES]
#
# CKPT_DIR      checkpoint directory including the trailing separator, e.g.
#               '../tf-model/'; the graph is read from CKPT_DIR + '-1.meta',
#               i.e. a checkpoint written with global_step=1.
# OUTPUT_PB     path of the frozen GraphDef to write.
# OUTPUT_NODES  optional comma-separated output node names (no ":0" suffix),
#               e.g. "smooth/output,whichPun/output". Default: "whichPun/output".
if len(argv) < 3:
    raise SystemExit("usage: ckpt2pb.py CKPT_DIR OUTPUT_PB [OUTPUT_NODES]")

ckpt_model_path = argv[1]
node_spec = argv[3] if len(argv) > 3 else "whichPun/output"
output_node_names = [name.strip() for name in node_spec.split(",") if name.strip()]

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    # clear_devices=True drops the device placements recorded at training time,
    # so the graph can be restored on a machine with a different device layout.
    saver = tf.train.import_meta_graph(ckpt_model_path + '-1.meta', clear_devices=True)
    saver.restore(sess, tf.train.latest_checkpoint(ckpt_model_path))

    # Replace every variable reachable from the output nodes with a constant.
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        output_node_names)

# Finally, serialize and dump the frozen graph to the filesystem.
with tf.gfile.GFile(argv[2], "wb") as f:
    f.write(output_graph_def.SerializeToString())
pb2tfs.py
#! encoding: utf-8

import numpy as np
from tensorflow.python.platform import gfile
import time
import os
import datetime
import tensorflow as tf
from sys import argv

from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.framework import graph_util

import shutil

# Usage: python pb2tfs.py FROZEN_PB
# Wraps a frozen GraphDef (e.g. the output of ckpt2pb.py) into a TF Serving
# SavedModel written to ./savedmodel.

# Load the frozen GraphDef once, into its own graph.
g1 = tf.Graph()
with g1.as_default():
    output_graph_def = tf.GraphDef()
    with gfile.GFile(argv[1], "rb") as f:
        output_graph_def.ParseFromString(f.read())
    tf.import_graph_def(output_graph_def, name="")

input_ids = g1.get_tensor_by_name("inputs/input_ids:0")
input_mask = g1.get_tensor_by_name("inputs/input_mask:0")
output = g1.get_tensor_by_name("whichPun/output:0")
print(output)

export_path = "savedmodel"
# In TF 1.x the builder creates its saver in, and serializes, the *default*
# graph, so g1 must be the default graph while exporting, not just sess.graph.
with g1.as_default(), tf.Session(graph=g1) as sess:
    # SavedModelBuilder refuses to write into an existing directory.
    shutil.rmtree(export_path, ignore_errors=True)

    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    data_in = tf.saved_model.utils.build_tensor_info(input_ids)
    data_in2 = tf.saved_model.utils.build_tensor_info(input_mask)
    data_out = tf.saved_model.utils.build_tensor_info(output)

    # The signature maps serving-side keys to tensors; the keys can be chosen
    # freely, independent of the tensor names in the graph.
    prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'input': data_in, 'mask': data_in2},
        outputs={'output': data_out},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

    # Record the graph under the SERVING tag (a frozen graph has no variables).
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'ac_forward': prediction_signature,
        })

    builder.save()

# Equivalent of `chmod -R 755 savedmodel`, without spawning a shell.
for root, _dirs, files in os.walk(export_path):
    os.chmod(root, 0o755)
    for name in files:
        os.chmod(os.path.join(root, name), 0o755)
graph2tfs.py
# coding=utf-8
import tensorflow as tf
import os
from sys import argv
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.framework import graph_util

import shutil

# Usage: python graph2tfs.py CKPT_DIR [MODE]
#
# CKPT_DIR holds the "-1.meta" graph and the checkpoint files (written with
# global_step=1). MODE selects what to produce; it defaults to "no_freeze", the
# only path the previous version of this script reached before its first exit():
#   no_freeze  export the restored graph and its variables to ./saved_model_no_freeze
#   pb         freeze the variables into constants and write ./graph.pb
#   frozen     freeze the variables and export the frozen graph to ./saved_model

# Output *node* names to keep when freezing (no ":0" suffix: the converter
# takes node names, not tensor names).
output_node_names = ["whichPun/output"]


def _export_saved_model(sess, export_path, input_key, input_ids, input_mask, output):
    """Export the default graph and sess's variables as a SERVING SavedModel.

    Must be called while sess.graph is the default graph: in TF 1.x the builder
    creates its saver in, and serializes, the default graph.

    Args:
        sess: session holding the restored (or frozen) model.
        export_path: output directory; removed first if it exists, because
            SavedModelBuilder refuses to write into an existing directory.
        input_key: signature key for the token-id input ('input' or 'ids').
        input_ids, input_mask, output: tensors wired into the 'ac_forward'
            signature as inputs {input_key, 'mask'} and output {'output'}.
    """
    shutil.rmtree(export_path, ignore_errors=True)
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    build = tf.saved_model.utils.build_tensor_info
    prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={input_key: build(input_ids), 'mask': build(input_mask)},
        outputs={'output': build(output)},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'ac_forward': prediction_signature,
        })
    builder.save()


mode = argv[2] if len(argv) > 2 else "no_freeze"
if mode not in ("no_freeze", "pb", "frozen"):
    raise SystemExit("unknown mode %r; expected no_freeze, pb or frozen" % mode)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph(argv[1] + '/-1.meta')
    saver.restore(sess, tf.train.latest_checkpoint(argv[1]))
    graph = sess.graph

    input_x = graph.get_tensor_by_name("inputs/input_ids:0")
    print(input_x)
    input_mask = graph.get_tensor_by_name("inputs/input_mask:0")
    print(input_mask)
    punc_out = graph.get_tensor_by_name("whichPun/output:0")

    if mode == "no_freeze":
        # Variables stay variables; the SavedModel stores their checkpoint.
        _export_saved_model(sess, 'saved_model_no_freeze', 'input',
                            input_x, input_mask, punc_out)
    else:
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_node_names)

        if mode == "pb":
            with tf.gfile.GFile('graph.pb', mode='wb') as f:
                f.write(frozen_graph_def.SerializeToString())
        else:
            # Re-import the frozen GraphDef into a fresh default graph so the
            # SavedModel contains only the constant-folded inference subgraph.
            with tf.Graph().as_default() as frozen_graph:
                tf.import_graph_def(frozen_graph_def, name="")
                with tf.Session() as frozen_sess:
                    final_out = frozen_graph.get_tensor_by_name("whichPun/output:0")
                    print('out1', final_out)
                    _export_saved_model(
                        frozen_sess, "saved_model", 'ids',
                        frozen_graph.get_tensor_by_name("inputs/input_ids:0"),
                        frozen_graph.get_tensor_by_name("inputs/input_mask:0"),
                        final_out)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值