由于python的灵活性和完备的生态库,使得其成为实现、验证ML算法的不二之选。但是工业界要将模型部署到生产环境上,需要考略性能问题,就不建议再使用python端的服务。这个从训练到部署的整个流程如下图所示:
基本可以把工作分为三块:Saver端 模型的离线训练与导出
Serving端 模型加载与在线预测
Client端 构建请求
本文采用 Saver (python) + Serving (tensorflow serving) + Client (Java) 作为解决方案,从零开始记录线上模型部署流程。
1、Saver
部署模型第一步是将训练好的整个模型导出为一系列标准格式的文件,然后即可在不同的平台上部署模型文件。TensorFlow 使用 SavedModel(pb文件) 这一格式用于模型部署。与Checkpoint 不同,SavedModel 包含了一个 TensorFlow 程序的完整信息: 不仅包含参数的权值,还包含计算图。
SavedModel最终保存结果包含两部分saved_model.pb和variables文件夹。
此处分别介绍,Tensorflow 1.0 和 2.0两个版本的导出方法。
1.1 Tensorflow 1.0 export
个人认为官方文档对具体使用写得不是特别明白,不想看官方文档的同学,可以对着示例照葫芦画瓢。其实也很简单,就两件事:Step 1、创建
builder = tf.saved_model.builder.SavedModelBuilder("out_dir")
# define signature which specify input and out nodes
predict_sig_def = (saved_model.signature_def_utils.build_signature_def(
inputs={"input_x":saved_model.build_tensor_info(fast_model.input_x)},
outputs={"out_y": saved_model.build_tensor_info(fast_model.y_pred_cls),
"score": saved_model.build_tensor_info(fast_model.logits)},
method_name=saved_model.signature_constants.PREDICT_METHOD_NAME))
# add graph and variables
builder.add_meta_graph_and_variables(sess, ["serve"],
signature_def_map={"fastText_sig_def": predict_sig_def},
main_op=tf.compat.v1.tables_initializer(),
strip_default_attrs=True)
builder.save()
需要注意的是,此处保存时的signature、input、out的相关属性诸如:name(自定义,不用和图内节点名称相同)
shape
data type
应与Client端传参对应。
1.2 Tensorflow 2.0 export
Keras 模型均可方便地导出为 SavedModel 格式。不过需要注意的是,因为 SavedModel 基于计算图,所以对于使用继承 tf.keras.Model 类建立的 Keras 模型,其需要导出到 SavedModel 格式的方法(比如 call )都需要使用 @tf.function 修饰。
class MLP(tf.keras.Model):
def __init__(self):
super().__init__()
self.flatten = tf.keras.layers.Flatten()
self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(units=10)
@tf.function
def call(self, inputs): # [batch_size, 28, 28, 1]
x = self.flatten(inputs) # [batch_size, 784]
x = self.dense1(x) # [batch_size, 100]
x = self.dense2(x) # [batch_size, 10]
output = tf.nn.softmax(x)
return output
model = MLP()
然后使用下面的代码即可将模型导出为 SavedModel
tf.saved_model.save(model, "保存的目标文件夹名称")
1.3 check SavedModel
如果想要检查保存的模型的SignatureDef、Inputs、Outputs等信息,可在cmd下使用命令:
saved_model_cli show --dir model_dir_path --all
2、Serving
模型保存好,就到Serving端的加载与预测步骤了。在介绍Tensorflow Serving之前,先介绍下基于 Tensorflow Java lib 的解决方案。
2.1 Tensorflow Java lib
Tensorflow提供了一个Java API(本质上是Java封装了C++的动态库), 允许在Java可以很方便的加载SavedModel, 并调用模型推理。
2.1.1 添加依赖
首先,在maven的pom.xml中添加依赖,此处tensorflow的版本最好与python训练版本一致。
org.tensorflow
tensorflow
1.11.0
2.1.2 Load & Predict
然后,加载模型,调用模型在线预测。以fast text模型为例,代码如下:
package model;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Graph;
import org.tensorflow.Tensor;
public class FastTextModel {
SavedModelBundle tensorflowModelBundle;
Session tensorflowSession;
void load(String modelPath){
this.tensorflowModelBundle = SavedModelBundle.load(modelPath, "serve");
this.tensorflowSession = tensorflowModelBundle.session();
}
public Tensor predict(Tensor tensorInput){
// feed()传参类似python端的feed_dict // fetch()指定输出节点的名称 Tensor out