之前有人说怎么将t2t的训练模型部署起来,其实不难!
首先,是安装tensorflow-model-server 可以自行百度!
然后进行下列操作:
这里假设你已经有了训练好的t2t模型
模型导出:
t2t-exporter \
--t2t_usr_dir=$T2T_USR_DIR \
--model=$MODEL \
--hparams_set=$HPARAMS \
--problem=$PROBLEM \
--data_dir=$DATA_DIR \
--output_dir=$TRAIN_DIR
模型部署:
tensorflow_model_base_server \
--port=9000 \
--model_name=my_model \
--model_base_path=$TRAIN_DIR/export/Servo
注意这里的tensor2tensor版本是 1.7.x
然后进行模型的调用,也就是使用模型进行预测:
这里是根据官方代码修改而来:只需在请求处理函数中调用下面的这个函数就行!
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from oauth2client.client import GoogleCredentials
from six.moves import input # pylint: disable=redefined-builtin
from tensor2tensor import problems as problems_lib # pylint: disable=unused-import
from tensor2tensor.serving import serving_utils
from tensor2tensor.utils import registry
from tensor2tensor.utils import usr_dir
import tensorflow as tf
def make_request_fn(server_name, server_address):
"""Returns a request function."""
request_fn = serving_utils.make_grpc_request_fn(
servable_name=server_name,
server=server_address,
timeout_secs=10)
return request_fn
def query_t2t(input_txt, data_dir, problem_name, server_name, server_address, t2t_usr_dir):
usr_dir.import_usr_dir(t2t_usr_dir)
problem = registry.problem(problem_name)
hparams = tf.contrib.training.HParams(
data_dir=os.path.expanduser(data_dir))
problem.get_hparams(hparams)
request_fn = make_request_fn(server_name, server_address)
inputs = input_txt
outputs = serving_utils.predict([inputs], problem, request_fn)
output, score = outputs
return output, score
上面的函数参数,分别是:
输入内容、数据文件夹,问题名称,服务名称,服务的地址,自定义文件夹
其中这些参数都是和上述命令中对应的,端口是默认9000,可以根据需要进行更改,欢迎留言评论讨论
如果是用的docker进行部署,注意端口的映射,容器内端口映射到服务端口