tensorflow模型固化与测试

模型固化:

# -*- coding:utf-8 -*-
import os, argparse
import tensorflow as tf
from tensorflow.python.framework import graph_util

dir = os.path.dirname(os.path.realpath('__file__'))


def freeze_graph(model_folder):
    # We retrieve our model-7000.ckpt fullpath
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    # **********修改.meta文件路径************
    input_checkpoint = "E:/python_pycharm/temp/z-intent recognition/output/char_rnn_polite_rude1_2/checkpoints/model-7000"
    # We precise the file fullname of our freezed graph
    absolute_model_folder = "/".join(input_checkpoint.split('/')[:-1])
    # **********修改固化后的模型路径************
    output_graph = "./model/frozen_model1.pb"
    # Before exporting our graph, we need to precise what is our output node
    # this variables is plural, because you can have multiple output nodes
    # freeze之前必须明确哪个是输出结点,也就是我们要得到推论结果的结点
    # 输出结点可以看我们模型的定义
    # 只有定义了输出结点,freeze才会把得到输出结点所必要的结点都保存下来,或者哪些结点可以丢弃
    # 所以,output_node_names必须根据不同的网络进行修改
    # **********修改固化后输出的节点************
    output_node_names = ['output/predictions_label']
    # We clear the devices, to allow TensorFlow to control on the loading where it wants operations to be calculated
    clear_devices = True
    # We import the meta graph and retrive a Saver
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    # We start a session and restore the graph weights
    # 这边已经将训练好的参数加载进来,也即最后保存,name='output'的模型是有图,并且图里面已经有参数了,所以才叫做是frozen
    # 相当于将参数已经固化在了图当中
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        print("tf")
        # 此处会打印所有结点,固化后只会保留相关结点
        for op in tf.get_default_graph().get_operations():
            wr_path = "./test1.txt"
            with open(wr_path, 'w', encoding='utf-8') as f:
                f.write(str(op.name))
                f.write("\n")
                f.write("***********")
                f.write("\n")
                f.write(str(op.values))
                f.write("\n")
                f.write("\n")
            print(op.name)
            print("===")
            print(op.values)
            print("\n")
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names=output_node_names  # We split on comma for convenience
        )

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # **********修改default路径************
    parser.add_argument("--model_folder", default="E:/python_pycharm/temp/z-intent recognition/output/char_rnn_polite_rude1_2/checkpoints/", type=str, help="Model folder to export")
    args = parser.parse_args()
    freeze_graph(args.model_folder)

模型测试:

# -*- coding:utf-8 -*-
import argparse
import tensorflow as tf
from numpy import *


def load_graph(frozen_graph_filename):
    # We parse the graph_def file
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # We load the graph_def in the default graph
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name="prefix",
            op_dict=None,
            producer_op_list=None
        )
    return graph


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # **********修改default加载模型的路径************
    parser.add_argument("--frozen_model_filename", default="./model/frozen_model.pb", type=str,
                        help="Frozen model file to import")
    args = parser.parse_args()
    # 加载已经将参数固化后的图
    graph = load_graph(args.frozen_model_filename)

    # We can list operations
    # op.values() gives you a list of tensors it produces
    # op.name gives you the name
    # 输入,输出结点也是operation,所以,我们可以得到operation的名字
    wr_path = "./test.txt"
    f = open(wr_path, 'w', encoding='utf-8')
    for op in graph.get_operations():
        print(op.name, op.values())
        f.write(str(op.name))
        f.write("\n")
        f.write("***********")
        f.write("\n")
        f.write(str(op.values))
        f.write("\n")
        f.write("\n")
    f.close()
        # prefix/Placeholder/inputs_placeholder
        # ...
        # prefix/Accuracy/predictions
    # 操作有:prefix/Placeholder/inputs_placeholder
    # 操作有:prefix/Accuracy/predictions
    # 为了预测,我们需要找到我们需要feed的tensor,那么就需要该tensor的名字
    # 注意prefix/Placeholder/inputs_placeholder仅仅是操作的名字,prefix/Placeholder/inputs_placeholder:0才是tensor的名字
    # ChatInputs = graph.get_tensor_by_name('prefix/ChatInputs:0')
    # transitions = graph.get_tensor_by_name('prefix/crf_loss/transitions:0')
    # Dropout = graph.get_tensor_by_name('prefix/Dropout:0')
    # Targets = graph.get_tensor_by_name('prefix/Targets:0')
    #
    # with tf.Session(graph=graph) as sess:
    #     result = sess.run(transitions, feed_dict={
    #         ChatInputs: [[1,2,3]], Dropout: 1.0, Targets: [1,2,3]} )
    #     print(result)
    # print ("finish")

	# **********根据节点name获取节点************
    Input = graph.get_tensor_by_name('prefix/x:0')
    Output = graph.get_tensor_by_name('prefix/output/predictions_label:0')
    Training = graph.get_tensor_by_name('prefix/is_training:0')

	# **********修改run()方法参数************
    data1 = mat(zeros((1, 1000)))
    with tf.Session(graph=graph) as sess:
        result = sess.run(Output, feed_dict={
            Input: data1,Training:False})
        print(result)
    print("finish")

参考链接:
https://blog.csdn.net/yuan_da_xian/article/details/83868025
https://www.jianshu.com/p/091415b114e2

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值