TensorFlow:如何冻结模型并使用python API提供服务
我们将探讨在生产中使用ML模型的两个部分:
- 如何导出模型并为其提供简单的自给自足文件
- 如何使用TF构建一个简单的python服务器(使用flask)
注意:如果你想看到我保存/加载/冻结的图表类型,你可以在这里
如何冻结(导出)已保存的模型
如果您想知道如何使用TensorFlow保存模型,请在继续之前查看我之前的文章。
让我们从包含模型的文件夹开始,它可能看起来像这样:
冻结模型之前生成的文件夹的屏幕截图
这里最重要的文件是“ .chkp”的人。如果你记得很清楚,对于不同时间步的每一对,一个持有权重(“.data”)而另一个(“.meta”)持有图形及其所有元数据(所以你可以重新训练它等等)
但是当我们想要在生产中提供模型时,我们不需要任何特殊的元数据来混淆我们的文件,我们只希望我们的模型及其权重很好地打包在一个文件中。这有助于您的不同型号的存储,版本控制和更新。
幸运的是,在TF中,我们可以轻松地构建自己的功能来实现它。让我们探讨一下我们必须执行的不同步骤:
- 检索我们保存的图:我们需要在默认图中加载以前保存的元图并检索其graph_def(图的ProtoBuf定义)
- 恢复权重:我们启动一个会话并恢复该会话内图表的权重
- 删除所有无用于推理的元数据:在这里,TF帮助我们提供了一个很好的帮助函数,它可以抓取图形中需要执行推理的内容并返回我们称之为新的“冻结graph_def”的内容。
- 将其保存到磁盘:最后,我们将序列化我们的冻结graph_def ProtoBuf并将其转储到磁盘
请注意,前两个步骤与我们在TF中加载任何图形时相同,唯一棘手的部分实际上是图形“冻结”,TF具有内置函数来执行此操作!
我提供了一个稍微不同的版本,这个版本更简单,我觉得很方便。TF提供的原始freeze_graph函数安装在您的bin
目录中,如果您使用PIP安装TF,则可以直接调用。如果没有,您可以直接从其文件夹中调用它(请参阅要点中的注释导入)。
所以让我们看看:
import os, argparse
import tensorflow as tf
# The original freeze_graph function
# from tensorflow.python.tools.freeze_graph import freeze_graph
dir = os.path.dirname(os.path.realpath(__file__))
def freeze_graph(model_dir, output_node_names):
"""Extract the sub graph defined by the output nodes and convert
all its variables into constant
Args:
model_dir: the root folder containing the checkpoint state file
output_node_names: a string, containing all the output node's names,
comma separated
"""
if not tf.gfile.Exists(model_dir):
raise AssertionError(
"Export directory doesn't exists. Please specify an export "
"directory: %s" % model_dir)
if not output_node_names:
print("You need to supply the name of a node to --output_node_names.")
return -1
# We retrieve our checkpoint fullpath
checkpoint = tf.train.get_checkpoint_state(model_dir)
input_checkpoint = checkpoint.model_checkpoint_path
# We precise the file fullname of our freezed graph
absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
output_graph = absolute_model_dir + "/frozen_model.pb"
# We clear devices to allow TensorFlow to control on which device it will load operations
clear_devices = True
# We start a session using a temporary fresh Graph
with tf.Session(graph=tf.Graph()) as sess:
# We import the meta graph in the current default Graph
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)
# We restore the weights
saver.restore(sess, input_checkpoint)
# We use a built-in TF helper to export variables to constants
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, # The session is used to retrieve the weights
tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes
output_node_names.split(",") # The output node names are used to select the usefull nodes
)
# Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
return output_graph_def
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, default="", help="Model folder to export")
parser.add_argument("--output_node_names", type=str, default="", help="The name of the output nodes, comma separated.")
args = parser.parse_args()
freeze_graph(args.model_dir, args.output_node_names)
现在我们可以在文件夹中看到一个新文件:“frozen_model.pb”。
冻结模型后生成的文件夹的屏幕截图
正如所料,它的大小比权重文件大小大,并且低于两个检查点文件大小的总和。
注意:在这个非常简单的情况下,权重文件大小非常小,但通常是多个Mbs。
如何使用冷冻模型
当然,在知道如何冻结模型之后,人们可能想知道如何使用它。
要记住的小技巧是要理解我们转储到磁盘的是graph_def ProtoBuf。因此,要将其导入python脚本,我们需要:
- 首先导入graph_def ProtoBuf
- 将此graph_def加载到实际的Graph中
我们可以建立一个方便的功能:
import tensorflow as tf
def load_graph(frozen_graph_filename):
# We load the protobuf file from the disk and parse it to retrieve the
# unserialized graph_def
with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# Then, we import the graph_def into a new Graph and returns it
with tf.Graph().as_default() as graph:
# The name var will prefix every op/nodes in your graph
# Since we load everything in a new graph, this is not needed
tf.import_graph_def(graph_def, name="prefix")
return graph
现在我们构建了加载冻结模型的函数,让我们创建一个简单的脚本来最终使用它:
import argparse
import tensorflow as tf
if __name__ == '__main__':
# Let's allow the user to pass the filename as an argument
parser = argparse.ArgumentParser()
parser.add_argument("--frozen_model_filename", default="results/frozen_model.pb", type=str, help="Frozen model file to import")
args = parser.parse_args()
# We use our "load_graph" function
graph = load_graph(args.frozen_model_filename)
# We can verify that we can access the list of operations in the graph
for op in graph.get_operations():
print(op.name)
# prefix/Placeholder/inputs_placeholder
# ...
# prefix/Accuracy/predictions
# We access the input and output nodes
x = graph.get_tensor_by_name('prefix/Placeholder/inputs_placeholder:0')
y = graph.get_tensor_by_name('prefix/Accuracy/predictions:0')
# We launch a Session
with tf.Session(graph=graph) as sess:
# Note: we don't nee to initialize/restore anything
# There is no Variables in this graph, only hardcoded constants
y_out = sess.run(y, feed_dict={
x: [[3, 5, 7, 4, 5, 1, 1, 1, 1, 1]] # < 45
})
# I taught a neural net to recognise when a sum of numbers is bigger than 45
# it should return False in this case
print(y_out) # [[ False ]] Yay, it works!
注意:加载冻结模型时,所有操作都以“prefix”为前缀。这是由于“import_graph_def”函数中的参数“name”,默认情况下它以“import”为前缀。
如果要在现有图形中导入graph_def,这对于避免名称冲突很有用。
如何构建(非常)简单的API
对于这部分,我将让代码说明一切。毕竟这是关于TF的TF系列而不是如何在python中构建服务器。然而,如果没有它,它会感觉有点未完成,所以在这里,最后的工作流程:
import json, argparse, time
import tensorflow as tf
from load import load_graph
from flask import Flask, request
from flask_cors import CORS
##################################################
# API part
##################################################
app = Flask(__name__)
cors = CORS(app)
@app.route("/api/predict", methods=['POST'])
def predict():
start = time.time()
data = request.data.decode("utf-8")
if data == "":
params = request.form
x_in = json.loads(params['x'])
else:
params = json.loads(data)
x_in = params['x']
##################################################
# Tensorflow part
##################################################
y_out = persistent_sess.run(y, feed_dict={
x: x_in
# x: [[3, 5, 7, 4, 5, 1, 1, 1, 1, 1]] # < 45
})
##################################################
# END Tensorflow part
##################################################
json_data = json.dumps({'y': y_out.tolist()})
print("Time spent handling the request: %f" % (time.time() - start))
return json_data
##################################################
# END API part
##################################################
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--frozen_model_filename", default="results/frozen_model.pb", type=str, help="Frozen model file to import")
parser.add_argument("--gpu_memory", default=.2, type=float, help="GPU memory per process")
args = parser.parse_args()
##################################################
# Tensorflow part
##################################################
print('Loading the model')
graph = load_graph(args.frozen_model_filename)
x = graph.get_tensor_by_name('prefix/Placeholder/inputs_placeholder:0')
y = graph.get_tensor_by_name('prefix/Accuracy/predictions:0')
print('Starting Session, setting the GPU memory usage to %f' % args.gpu_memory)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory)
sess_config = tf.ConfigProto(gpu_options=gpu_options)
persistent_sess = tf.Session(graph=graph, config=sess_config)
##################################################
# END Tensorflow part
##################################################
print('Starting the API')
app.run()
注意:我们在这个例子中使用了flask
TensorFlow最佳实践系列
本文是关于TensorFlow的更完整系列文章的一部分。我还没有定义这个系列的所有不同主题,所以如果你想看到TensorFlow的任何领域,请添加评论!到目前为止,我想探索这些主题(这个列表可能会有所变化,并没有特别的顺序):
- 一本底漆
- 如何在TensorFlow中处理形状
- TensorFlow保存/恢复和混合多个模型
- 如何冻结模型并使用python(这一个!)提供服务
- TensorFlow:文件,文件夹和模型架构的良好实践建议
- TensorFlow howto:神经网络内的通用逼近器
- 如何使用队列和多线程优化输入管道
- 变异变量和控制流程
- 如何使用TensorFlow处理输入数据。
- 如何控制渐变以创建自定义反向支撑或微调我的模型。
- 如何监控和检查我的模型以深入了解它们。
注意: TF现在正在快速发展,这些文章目前是为1.0.0版本编写的。