python训练模型生产_tensorflow 模型部署生产环境

由于python的灵活性和完备的生态库,使得其成为实现、验证ML算法的不二之选。但是工业界要将模型部署到生产环境上,需要考略性能问题,就不建议再使用python端的服务。这个从训练到部署的整个流程如下图所示:

基本可以把工作分为三块:Saver端 模型的离线训练与导出

Serving端 模型加载与在线预测

Client端 构建请求

本文采用 Saver (python) + Serving (tensorflow serving) + Client (Java) 作为解决方案,从零开始记录线上模型部署流程。

1、Saver

部署模型第一步是将训练好的整个模型导出为一系列标准格式的文件,然后即可在不同的平台上部署模型文件。TensorFlow 使用 SavedModel(pb文件) 这一格式用于模型部署。与Checkpoint 不同,SavedModel 包含了一个 TensorFlow 程序的完整信息: 不仅包含参数的权值,还包含计算图。

SavedModel最终保存结果包含两部分saved_model.pb和variables文件夹。

此处分别介绍,Tensorflow 1.0 和 2.0两个版本的导出方法。

1.1 Tensorflow 1.0 export

个人认为官方文档对具体使用写得不是特别明白,不想看官方文档的同学,可以对着示例照葫芦画瓢。其实也很简单,就两件事:Step 1、创建

builder = tf.saved_model.builder.SavedModelBuilder("out_dir")

# define signature which specify input and out nodes

predict_sig_def = (saved_model.signature_def_utils.build_signature_def(

inputs={"input_x":saved_model.build_tensor_info(fast_model.input_x)},

outputs={"out_y": saved_model.build_tensor_info(fast_model.y_pred_cls),

"score": saved_model.build_tensor_info(fast_model.logits)},

method_name=saved_model.signature_constants.PREDICT_METHOD_NAME))

# add graph and variables

builder.add_meta_graph_and_variables(sess, ["serve"],

signature_def_map={"fastText_sig_def": predict_sig_def},

main_op=tf.compat.v1.tables_initializer(),

strip_default_attrs=True)

builder.save()

需要注意的是,此处保存时的signature、input、out的相关属性诸如:name(自定义,不用和图内节点名称相同)

shape

data type

应与Client端传参对应。

1.2 Tensorflow 2.0 export

Keras 模型均可方便地导出为 SavedModel 格式。不过需要注意的是,因为 SavedModel 基于计算图,所以对于使用继承 tf.keras.Model 类建立的 Keras 模型,其需要导出到 SavedModel 格式的方法(比如 call )都需要使用 @tf.function 修饰。

class MLP(tf.keras.Model):

def __init__(self):

super().__init__()

self.flatten = tf.keras.layers.Flatten()

self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)

self.dense2 = tf.keras.layers.Dense(units=10)

@tf.function

def call(self, inputs): # [batch_size, 28, 28, 1]

x = self.flatten(inputs) # [batch_size, 784]

x = self.dense1(x) # [batch_size, 100]

x = self.dense2(x) # [batch_size, 10]

output = tf.nn.softmax(x)

return output

model = MLP()

然后使用下面的代码即可将模型导出为 SavedModel

tf.saved_model.save(model, "保存的目标文件夹名称")

1.3 check SavedModel

如果想要检查保存的模型的SignatureDef、Inputs、Outputs等信息,可在cmd下使用命令:

saved_model_cli show --dir model_dir_path --all

2、Serving

模型保存好,就到Serving端的加载与预测步骤了。在介绍Tensorflow Serving之前,先介绍下基于 Tensorflow Java lib 的解决方案。

2.1 Tensorflow Java lib

Tensorflow提供了一个Java API(本质上是Java封装了C++的动态库), 允许在Java可以很方便的加载SavedModel, 并调用模型推理。

2.1.1 添加依赖

首先,在maven的pom.xml中添加依赖,此处tensorflow的版本最好与python训练版本一致。

org.tensorflow

tensorflow

1.11.0

2.1.2 Load & Predict

然后,加载模型,调用模型在线预测。以fast text模型为例,代码如下:

package model;

import org.tensorflow.SavedModelBundle;

import org.tensorflow.Session;

import org.tensorflow.Graph;

import org.tensorflow.Tensor;

public class FastTextModel {

SavedModelBundle tensorflowModelBundle;

Session tensorflowSession;

void load(String modelPath){

this.tensorflowModelBundle = SavedModelBundle.load(modelPath, "serve");

this.tensorflowSession = tensorflowModelBundle.session();

}

public Tensor predict(Tensor tensorInput){

// feed()传参类似python端的feed_dict // fetch()指定输出节点的名称 Tensor out

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值