TensorFlow:如何冻结模型并使用python API提供服务

TensorFlow:如何冻结模型并使用python API提供服务

 

我们将探讨在生产中使用ML模型的两个部分:

  • 如何导出模型并为其提供简单的自给自足文件
  • 如何使用TF构建一个简单的python服务器(使用flask)

注意:如果你想看到我保存/加载/冻结的图表类型,你可以在这里

如何冻结(导出)已保存的模型

如果您想知道如何使用TensorFlow保存模型,请在继续之前查看我之前的文章

让我们从包含模型的文件夹开始,它可能看起来像这样:

 

 

冻结模型之前生成的文件夹的屏幕截图

这里最重要的文件是“ .chkp”的人如果你记得很清楚,对于不同时间步的每一对,一个持有权重(“.data”)而另一个(“.meta”)持有图形及其所有元数据(所以你可以重新训练它等等)

但是当我们想要在生产中提供模型时,我们不需要任何特殊的元数据来混淆我们的文件,我们只希望我们的模型及其权重很好地打包在一个文件中。这有助于您的不同型号的存储,版本控制和更新。

幸运的是,在TF中,我们可以轻松地构建自己的功能来实现它。让我们探讨一下我们必须执行的不同步骤:

  • 检索我们保存的图:我们需要在默认图中加载以前保存的元图并检索其graph_def(图的ProtoBuf定义)
  • 恢复权重:我们启动一个会话并恢复该会话内图表的权重
  • 删除所有无用于推理的元数据:在这里,TF帮助我们提供了一个很好的帮助函数,它可以抓取图形中需要执行推理的内容并返回我们称之为新的“冻结graph_def”的内容。
  • 将其保存到磁盘:最后,我们将序列化我们的冻结graph_def ProtoBuf并将其转储到磁盘

请注意,前两个步骤与我们在TF中加载任何图形时相同,唯一棘手的部分实际上是图形“冻结”,TF具有内置函数来执行此操作!

我提供了一个稍微不同的版本,这个版本更简单,我觉得很方便。TF提供的原始freeze_graph函数安装在您的bin目录中,如果您使用PIP安装TF,则可以直接调用。如果没有,您可以直接从其文件夹中调用它(请参阅要点中的注释导入)。

所以让我们看看:

import os, argparse

import tensorflow as tf

# The original freeze_graph function
# from tensorflow.python.tools.freeze_graph import freeze_graph 

dir = os.path.dirname(os.path.realpath(__file__))

def freeze_graph(model_dir, output_node_names):
    """Extract the sub graph defined by the output nodes and convert 
    all its variables into constant 
    Args:
        model_dir: the root folder containing the checkpoint state file
        output_node_names: a string, containing all the output node's names, 
                            comma separated
    """
    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            "directory: %s" % model_dir)

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path
    
    # We precise the file fullname of our freezed graph
    absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + "/frozen_model.pb"

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We start a session using a temporary fresh Graph
    with tf.Session(graph=tf.Graph()) as sess:
        # We import the meta graph in the current default Graph
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

        # We restore the weights
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes 
            output_node_names.split(",") # The output node names are used to select the usefull nodes
        ) 

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="", help="Model folder to export")
    parser.add_argument("--output_node_names", type=str, default="", help="The name of the output nodes, comma separated.")
    args = parser.parse_args()

    freeze_graph(args.model_dir, args.output_node_names)

现在我们可以在文件夹中看到一个新文件:“frozen_model.pb”。

冻结模型后生成的文件夹的屏幕截图

正如所料,它的大小比权重文件大小大,并且低于两个检查点文件大小的总和。

注意:在这个非常简单的情况下,权重文件大小非常小,但通常是多个Mbs。

如何使用冷冻模型

当然,在知道如何冻结模型之后,人们可能想知道如何使用它。

要记住的小技巧是要理解我们转储到磁盘的是graph_def ProtoBuf。因此,要将其导入python脚本,我们需要:

  • 首先导入graph_def ProtoBuf
  • 将此graph_def加载到实际的Graph中

我们可以建立一个方便的功能:

import tensorflow as tf

def load_graph(frozen_graph_filename):
    # We load the protobuf file from the disk and parse it to retrieve the 
    # unserialized graph_def
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then, we import the graph_def into a new Graph and returns it 
    with tf.Graph().as_default() as graph:
        # The name var will prefix every op/nodes in your graph
        # Since we load everything in a new graph, this is not needed
        tf.import_graph_def(graph_def, name="prefix")
    return graph

现在我们构建了加载冻结模型的函数,让我们创建一个简单的脚本来最终使用它:

import argparse 
import tensorflow as tf

if __name__ == '__main__':
    # Let's allow the user to pass the filename as an argument
    parser = argparse.ArgumentParser()
    parser.add_argument("--frozen_model_filename", default="results/frozen_model.pb", type=str, help="Frozen model file to import")
    args = parser.parse_args()

    # We use our "load_graph" function
    graph = load_graph(args.frozen_model_filename)

    # We can verify that we can access the list of operations in the graph
    for op in graph.get_operations():
        print(op.name)
        # prefix/Placeholder/inputs_placeholder
        # ...
        # prefix/Accuracy/predictions
        
    # We access the input and output nodes 
    x = graph.get_tensor_by_name('prefix/Placeholder/inputs_placeholder:0')
    y = graph.get_tensor_by_name('prefix/Accuracy/predictions:0')
        
    # We launch a Session
    with tf.Session(graph=graph) as sess:
        # Note: we don't nee to initialize/restore anything
        # There is no Variables in this graph, only hardcoded constants 
        y_out = sess.run(y, feed_dict={
            x: [[3, 5, 7, 4, 5, 1, 1, 1, 1, 1]] # < 45
        })
        # I taught a neural net to recognise when a sum of numbers is bigger than 45
        # it should return False in this case
        print(y_out) # [[ False ]] Yay, it works!

注意:加载冻结模型时,所有操作都以“prefix”为前缀。这是由于“import_graph_def”函数中的参数“name”,默认情况下它以“import”为前缀。

如果要在现有图形中导入graph_def,这对于避免名称冲突很有用。

如何构建(非常)简单的API

对于这部分,我将让代码说明一切。毕竟这是关于TF的TF系列而不是如何在python中构建服务器。然而,如果没有它,它会感觉有点未完成,所以在这里,最后的工作流程:

import json, argparse, time

import tensorflow as tf
from load import load_graph

from flask import Flask, request
from flask_cors import CORS

##################################################
# API part
##################################################
app = Flask(__name__)
cors = CORS(app)
@app.route("/api/predict", methods=['POST'])
def predict():
    start = time.time()
    
    data = request.data.decode("utf-8")
    if data == "":
        params = request.form
        x_in = json.loads(params['x'])
    else:
        params = json.loads(data)
        x_in = params['x']

    ##################################################
    # Tensorflow part
    ##################################################
    y_out = persistent_sess.run(y, feed_dict={
        x: x_in
        # x: [[3, 5, 7, 4, 5, 1, 1, 1, 1, 1]] # < 45
    })
    ##################################################
    # END Tensorflow part
    ##################################################
    
    json_data = json.dumps({'y': y_out.tolist()})
    print("Time spent handling the request: %f" % (time.time() - start))
    
    return json_data
##################################################
# END API part
##################################################

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--frozen_model_filename", default="results/frozen_model.pb", type=str, help="Frozen model file to import")
    parser.add_argument("--gpu_memory", default=.2, type=float, help="GPU memory per process")
    args = parser.parse_args()

    ##################################################
    # Tensorflow part
    ##################################################
    print('Loading the model')
    graph = load_graph(args.frozen_model_filename)
    x = graph.get_tensor_by_name('prefix/Placeholder/inputs_placeholder:0')
    y = graph.get_tensor_by_name('prefix/Accuracy/predictions:0')

    print('Starting Session, setting the GPU memory usage to %f' % args.gpu_memory)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory)
    sess_config = tf.ConfigProto(gpu_options=gpu_options)
    persistent_sess = tf.Session(graph=graph, config=sess_config)
    ##################################################
    # END Tensorflow part
    ##################################################

    print('Starting the API')
    app.run()

注意:我们在这个例子中使用了flask


TensorFlow最佳实践系列

本文是关于TensorFlow的更完整系列文章的一部分。我还没有定义这个系列的所有不同主题,所以如果你想看到TensorFlow的任何领域,请添加评论!到目前为止,我想探索这些主题(这个列表可能会有所变化,并没有特别的顺序):

注意: TF现在正在快速发展,这些文章目前是为1.0.0版本编写的。


参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值