本文中我们在Mac机器上使用Docker配置TensorFlow-Serving环境,并提供Http预测接口。
安装Docker
brew cask install docker
下载TensorFlow-Serving镜像
docker pull tensorflow/serving
生成SavedModel模型
TensorFlow主要有三种模型格式:CheckPoint(.ckpt),SavedModel,GraphDef(*.pb)。这三种格式之间可以互相转换,CheckPoint格式在训练模型时候每隔几轮保存一次,以方便增量训练。GraphDef格式适用于python、java的tensorflow库进行加载,SavedModel是TensorFlow-Serving要求的格式。我们最开始的是一个xception.pb模型(GraphDef格式),这里需要将其转换为SavedModel格式,代码如下:
import tensorflow.compat.v1 as tf
import time
tf.disable_v2_behavior()
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
export_dir = '/Work/infra/tensorflow/saved_model'
graph_pb = '/Work/infra/tensorflow/xception.pb'
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.gfile.GFile(graph_pb, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sigs = {}
with tf.Session(graph=tf.Graph()) as sess:
# name="" is important to ensure we don't get spurious prefixing
tf.import_graph_def(graph_def, name="")
g = tf.get_default_graph()
inp = g.get_tensor_by_name("input_1:0") //输入节点名字
out = g.get_tensor_by_name("output:0") // 输出节点名字
sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
tf.saved_model.signature_def_utils.predict_signature_def(
{"in": inp}, {"out": out})
builder.add_meta_graph_and_variables(sess,
[tag_constants.SERVING],
signature_def_map=sigs)
builder.save()
转换完成后,目录下除了saved_model.pb文件,还多了一个variables文件夹,该文件夹为空,因为从pb文件转换过来的时候全是常量没有变量。
启动TensorFlow-Serving
docker run -p 8501:8501 --name tfserving_testnet --mount type=bind,source=/Wor
k/infra/tensorflow/xception,target=/models/xception -e MODEL_NAME=xception -t tensorflow/serving
- 8051:http端口,前面的为本机的端口,后面的为docker中的端口。
- name:名字随便起,为了识别docker的container
- source:模型在本机上的位置目录
- target:/models/固定,后面的名字随便起,最好和模型名字一致
- MODEL_NAME:设置Docker中的环境变量,和上面的target名字一致
Http接口
- 查看TensorFlow-Serving状态:curl http://localhost:8501/v1/models/xception
- 查看TensorFlow-Serviing模型:curl http://localhost:8501/v1/models/xception/metadata
- 使用Http请求进行模型预测:curl -d ‘{“instances”: [1,2,3,4,5]}’ -X POST http://localhost:8501/v1/models/xception:predict,其中instances的value为模型输入Tensor的字符串形式,矩阵维度需要和Tensor对应。
Python客户端
这里我们使用Python加载图片数据,并发向TensorFlow Http接口进行预测图片质量。Http接口的数据需要是Json格式,代码如下:
SERVER_URL = 'http://localhost:8501/v1/models/xception:predict'
def prediction():
images = image.load_img("test.jpg", target_size=(480, 480))
x = image.img_to_array(images)
x = np.expand_dims(x, axis=0)
image_np = xception.preprocess_input(x)
#print str(image_np.tolist())
predict_request='{"instances":%s}' % str(image_np.tolist())
#predict_request='{"instances":%s}' % str([[[[1]*3]*480]*480])
response = requests.post(SERVER_URL, data=predict_request)
prediction = response.json()
print(prediction)
if __name__ == "__main__":
prediction()