Saving a tflearn model as a .pb file for use with TensorFlow

Original blog post

The final saved .pb result


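Notes:
The original post is very detailed and correct; I tested it myself and it works. Only two problems tend to come up:
(1) Before calling model.save, the training ops must be cleared from the graph, so the training script needs this extra line:

del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]

(2) Print all the node names and set output_node_names to the one from your own model. In my case it turned out to be the same as the original author's.

The original post is reproduced below: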

Reference: https://github.com/tflearn/tflearn/issues/964

Solution:

"""
Tensorflow graph freezer
Converts Tensorflow trained models in .pb

Code adapted from:
https://gist.github.com/morgangiraud/249505f540a5e53a48b0c1a869d370bf#file-medium-tffreeze-1-py
"""

import os, argparse
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.python.framework import graph_util

def freeze_graph(model_folder,output_graph="frozen_model.pb"):
    # We retrieve our checkpoint fullpath
    try:
        checkpoint = tf.train.get_checkpoint_state(model_folder)
        input_checkpoint = checkpoint.model_checkpoint_path
        print("[INFO] input_checkpoint:", input_checkpoint)
    except Exception:
        # No checkpoint state file was found: treat model_folder as the checkpoint prefix itself
        input_checkpoint = model_folder
        print("[INFO] Model folder", model_folder)

    # Before exporting our graph, we need to specify the output node
    # This is how TF decides which part of the graph to keep and which part it can discard
    output_node_names = "FullyConnected/Softmax" # NOTE: Change here

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True
    
    # We import the meta graph and retrieve a Saver
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # We start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,                        # The session is used to retrieve the weights
            input_graph_def,             # The graph_def is used to retrieve the nodes 
            output_node_names.split(",") # The output node names are used to select the useful nodes
        ) 

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

        print("[INFO] output_graph:",output_graph)
        print("[INFO] all done")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Tensorflow graph freezer\nConverts trained models to .pb file",
                                     prefix_chars='-')
    parser.add_argument("--mfolder", type=str, help="model folder to export")
    parser.add_argument("--ograph", type=str, help="output graph name", default="frozen_model.pb")
    
    args = parser.parse_args()
    print(args,"\n")

    freeze_graph(args.mfolder,args.ograph)

# However, before doing model.save(...) on TFLearn I have to do
# ************************************************************
# del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
# ************************************************************

"""
Then I call this command
python tf_freeze.py --mfolder=<path_to_tflearn_model>

Note

    The <path_to_tflearn_model> must not have the ".data-00000-of-00001".
    The output_node_names variable may change depending on your architecture. The thing is that you must reference the layer that has the softmax activation function.
"""
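The note about the ".data-00000-of-00001" suffix can be handled mechanically before calling the script. The following is a minimal sketch, assuming the usual TensorFlow 1.x checkpoint file suffixes (.data-XXXXX-of-XXXXX, .index, .meta); checkpoint_prefix is a hypothetical helper name and is not part of tflearn or TensorFlow.

```python
import re

# Suffixes that TensorFlow 1.x appends to a checkpoint prefix on disk
# (assumption: the V2 checkpoint format with sharded .data files).
_CKPT_SUFFIX = re.compile(r"\.(data-\d{5}-of-\d{5}|index|meta)$")


def checkpoint_prefix(path):
    """Return the checkpoint prefix to pass as --mfolder.

    Strips a trailing .data-XXXXX-of-XXXXX, .index or .meta suffix;
    any other path is returned unchanged.
    """
    return _CKPT_SUFFIX.sub("", path)


print(checkpoint_prefix("model.tflearn.data-00000-of-00001"))
print(checkpoint_prefix("model.tflearn.meta"))
print(checkpoint_prefix("model.tflearn"))
```

All three calls print the same prefix, which is the value the freeze script expects.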

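Notes:

1. Before calling tflearn's model.save, add:

del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]

Purpose: removes the training ops from the model.

Reference: https://github.com/tflearn/tflearn/issues/605#issuecomment-298478314

2. If the model contains batch normalization or residual network layers, the following error appears: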

Error when loading the frozen graph with tensorflow.contrib.layers.python.layers.batch_norm
ValueError: graph_def is invalid at node u'BatchNorm/cond/AssignMovingAvg/Switch': Input tensor 'BatchNorm/moving_mean:0' Cannot convert a tensor of type float32 to an input of type float32_ref
freeze_graph.py doesn't seem to store moving_mean and moving_variance properly

 

An ugly way to get it working:
manually replace the wrong node definitions in the frozen graph
RefSwitch --> Switch + add '/read' to the input names
AssignSub --> Sub + remove use_locking attributes

 

In that case, add the following after restoring the model:

# fix batch norm nodes
# gd is the GraphDef about to be frozen (input_graph_def in the script above)
for node in gd.node:
    if node.op == 'RefSwitch':
        node.op = 'Switch'
        for index in range(len(node.input)):
            if 'moving_' in node.input[index]:
                node.input[index] = node.input[index] + '/read'
    elif node.op == 'AssignSub':
        node.op = 'Sub'
        if 'use_locking' in node.attr:
            del node.attr['use_locking']

Reference: https://github.com/tensorflow/tensorflow/issues/3628
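To check whether a graph needs this patch at all, one option is to scan it for the two op types that the patch rewrites. The sketch below only assumes that each node exposes name and op attributes, as GraphDef nodes do; find_ref_nodes and the stand-in demo object are hypothetical and exist only for illustration.

```python
from types import SimpleNamespace

# Op types that the batch-norm patch rewrites.
REF_OPS = ("RefSwitch", "AssignSub")


def find_ref_nodes(graph_def):
    """Return (name, op) pairs for the nodes the patch would touch."""
    return [(node.name, node.op) for node in graph_def.node if node.op in REF_OPS]


# Stand-in for a real GraphDef, used only to illustrate the call.
demo = SimpleNamespace(node=[
    SimpleNamespace(name="BatchNorm/cond/AssignMovingAvg/Switch", op="RefSwitch"),
    SimpleNamespace(name="FullyConnected/Softmax", op="Softmax"),
])
print(find_ref_nodes(demo))
```

An empty result suggests the graph contains none of the problematic nodes, so the patch can be skipped.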

 

I met the same issue when I was trying to export graph and variables by saved_model module. And finally I found a workaround to fix this issue:

 

Remove the TRAIN_OPS collections from graph collection. e.g.:

 

with dnn.graph.as_default():
     del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]

 

The dumped graph may not be available for training again (by tflearn), but should be able to perform prediction and evaluation. This is useful when serving model by another module or language (e.g. tensorflow serving or tensorflow go binding). I'll do more further tests about this.

 

If you want to re-train the model, please use the built-in "save" method, then reconstruct the graph and load the saved data when re-training.

 

3. You may need to change this line in the code:

output_node_names = "FullyConnected/Softmax" # NOTE: Change here


Reference: https://gist.github.com/morgangiraud/249505f540a5e53a48b0c1a869d370bf#file-medium-tffreeze-1-py

@vparikh10 @ratfury @rakashi I faced the same situation just like you.
From what I understood, you may have to change this line according to your network definition.
In my case, instead of having output_node_names = "Accuracy/prediction", I have output_node_names = "FullyConnected_2/Softmax".

softmax

I made this change after reading this suggestion


In my case, neither softmax nor Softmax worked. So I printed all the node names, like this:

with tf.Session() as sess:
    model = get_cnn_model(max_len, volcab_size)
    model.fit(trainX, trainY, validation_set=(testX, testY), show_metric=True, batch_size=1000, n_epoch=1)
    init_op = tf.global_variables_initializer()  # replaces the deprecated tf.initialize_all_variables()
    sess.run(init_op)

    # Print the name of every op in the graph
    for v in sess.graph.get_operations():
        print(v.name)

Then make sure the value of output_node_names is in that list.
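The printed list can be long, so it may help to filter it for likely output nodes. The sketch below is one possible way to do that; candidate_output_nodes is a hypothetical helper, and the op names are made-up stand-ins for what [op.name for op in sess.graph.get_operations()] would return on a real model.

```python
def candidate_output_nodes(op_names, keyword="softmax"):
    """Return op names whose last path component contains keyword (case-insensitive)."""
    keyword = keyword.lower()
    return [name for name in op_names if keyword in name.split("/")[-1].lower()]


# Hypothetical op names standing in for a real graph's operations.
names = [
    "InputData/X",
    "FullyConnected/MatMul",
    "FullyConnected/Softmax",
    "Adam/apply_grad_op_0",
]
print(candidate_output_nodes(names))
```

The full name returned here (scope included), not the bare word "softmax", is what goes into output_node_names.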



Appendix: the code from the gist, which turns the output node names into a command-line argument

import os, argparse

import tensorflow as tf

# The original freeze_graph function
# from tensorflow.python.tools.freeze_graph import freeze_graph 

dir = os.path.dirname(os.path.realpath(__file__))

def freeze_graph(model_dir, output_node_names):
    """Extract the sub graph defined by the output nodes and convert 
    all its variables into constant 
    Args:
        model_dir: the root folder containing the checkpoint state file
        output_node_names: a string, containing all the output node's names, 
                            comma separated
    """
    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exist. Please specify an export "
            "directory: %s" % model_dir)

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path
    
    # We build the full path of the frozen graph file
    absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + "/frozen_model.pb"

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We start a session using a temporary fresh Graph
    with tf.Session(graph=tf.Graph()) as sess:
        # We import the meta graph in the current default Graph
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

        # We restore the weights
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes 
            output_node_names.split(",") # The output node names are used to select the useful nodes
        ) 

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="", help="Model folder to export")
    parser.add_argument("--output_node_names", type=str, default="", help="The name of the output nodes, comma separated.")
    args = parser.parse_args()

    freeze_graph(args.model_dir, args.output_node_names)
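Once frozen_model.pb exists, it can be loaded back into plain TensorFlow without tflearn. The following is a minimal sketch, assuming the TensorFlow 1.x API used throughout this post; load_frozen_graph is a hypothetical helper name, and the import sits inside the function only so that the snippet can be pasted into a module that is also imported on machines without TensorFlow.

```python
def load_frozen_graph(pb_path, prefix=""):
    """Load a frozen .pb file into a fresh tf.Graph and return that graph."""
    import tensorflow as tf  # assumption: TensorFlow 1.x

    # Read the serialized GraphDef written by the freeze script.
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import it into a new graph; an empty prefix keeps the node names unchanged.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=prefix)
    return graph
```

Tensors can then be fetched by name, for example graph.get_tensor_by_name("FullyConnected/Softmax:0"), and evaluated in a tf.Session(graph=graph). The input tensor name depends on your network, so print the op names to confirm it.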