目录
1:安装docker
windows中需要先安装WLS2,linux直接安装
2:docker pull tensorflow/serving
3:基于docker启动tensorflow serving
Windows
docker run -p 8500:8500 --mount type=bind,source=E:\pythonproject\app_backen210705\app_backen210705\app_backen\bpmonitor\savedmodel,target=/models/savedmodel -e MODEL_NAME=savedmodel -t tensorflow/serving
Linux
docker run -p 8501:8501 -v /opt/app_backen001/savedmodel_03081611:/models/savedmodel -e MODEL_NAME=savedmodel -t tensorflow/serving
注意
- source是本地模型的绝对路径,路径中不能有特殊符号,包括空格。
- 模型存放在savedmodel文件夹下面的一个名称为纯数字的文件夹内,这串数字表示这个模型的版本。而source路径中不包含这串数字。
- target路径中/models是固定的,后面的savedmodel应该与前面source路径里的名称一样,均为savedmodel,该名称被视为模型名称。
- MODEL_NAME就是模型名称。
- 如果使用restful进行通信,还需要
-p 8501:8501
启动后可以通过grpc或者restful传入数据并得到预测结果。
4: 使用tensorflow-serving的restful api进行交互
我使用的是restful的api
交互的时候要注意格式,严格按照模型的输入模型进行交互。
附上官网的api:https://tensorflow.google.cn/tfx/serving/api_rest#start_modelserver_with_the_rest_api_endpoint
当然使用gRPC交互也行。
下面是我使用restful api和tensorflow-serving通信交互的代码:
import requests
import numpy as np
import os
import json
from django.http import JsonResponse
class NumpyArrayEncoder(json.JSONEncoder):
def default(self,obj):
if isinstance(obj,np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self,obj)
def Savedmodel(input): # shape :(875,)
url = "http://127.0.0.1:8501/v1/models/savedmodel:predict"
input = [[x] for x in input] # shape:(875,1)
input_dict = {
"instances":[ { "input_1":input } ] # 这里非常容易错,且报错的提示内容很晦涩难懂,不要相信他报的什么rank的错
}
input_json = json.dumps(input_dict, cls = NumpyArrayEncoder)
response = requests.post(url,data=input_json)
if response.status_code == 200:
result = response.json()
# 输出result到日志
else:
# 报错
return result['predictions'][0]['xxx'][0]
这里的input对应的我的模型的metadata:
(metadata可以通过curl 或者 浏览器 访问http://127.0.0.1:8501/v1/models/savedmodel/metadata得到)