Tensorflow Serving部署及客户端访问编程实践

昨天我们实现了Tensorflow.js的花卉识别程序,它的优点是不需要服务器支持,在客户端就可以完成花卉识别,使用非常方便,但也存在一些缺点。对于很多深度学习的应用来说,由于其训练模型复杂、计算量大,所以,一般来说,仍然需要服务器支持。下面仍然以花卉识别为例,介绍如何部署Tensorflow Serving及客户端编程。

TensorFlow Serving 是由 Google 开发和维护的开源项目,是 TensorFlow 生态系统的一部分,专门用于高效地部署和服务机器学习模型,具有高性能、灵活性、易于集成、可扩展性、易于管理和健壮性等多方面的优点。最重要的是它与 TensorFlow 紧密集成,实现了与 TensorFlow 生态系统无缝集成,支持 TensorFlow 模型的完整生命周期管理,从训练到部署再到监控。并且能够直接加载和使用 TensorFlow 的 SavedModel 格式,无需额外的转换步骤。
相对于许多通用的 web 服务器和 API 服务器(如 Flask、Django、FastAPI 等),但 TensorFlow Serving 专门针对机器学习模型的服务进行了优化,包括高效的内存管理、请求批处理、多线程处理等,能够在高并发和高负载的场景下表现出色。

这里不介绍Tensorflow Serving的安装,只介绍与编程有关的部署等问题。

文末附完整源代码链接。

一、服务端部署训练模型

1. 配置模型

按如下目录存放训练SavedModel模型:

/path/to/your/model/
└── your_model/
    └── 1/
        ├── saved_model.pb
        └── variables/
            ├── variables.data-00000-of-00001
            └── variables.index

2. 启动 TensorFlow Serving

执行以下命令:

tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=your_model --model_base_path=/path/to/your/model/your_model

这条命令用于启动 TensorFlow Serving 服务器,加载指定的模型,并配置其服务端口和 API 端口。以下是每个参数的详细解释:

3. 命令和参数解释

tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=your_model --model_base_path=/path/to/your/model/your_model
  • tensorflow_model_server: 这是启动 TensorFlow Serving 服务器的命令。
  • --port=8500: 指定 gRPC API 的端口号。gRPC 是一种高性能的远程过程调用(RPC)框架,适用于需要高吞吐量和低延迟的应用场景。
  • --rest_api_port=8501: 指定 RESTful API 的端口号。RESTful API 基于 HTTP 协议,使用起来简单且广泛应用,方便客户端通过 HTTP 请求与 TensorFlow Serving 进行交互。
  • --model_name=your_model: 指定模型的名称。在服务中使用这个名称来引用和请求这个模型。这个名称可以在客户端请求中用来标识和调用特定的模型。
  • --model_base_path=/path/to/your/model/your_model: 指定模型所在的目录路径。TensorFlow Serving 会在这个目录中查找并加载模型。该路径应包含模型的文件和子目录。

二、客户端程序

1. 使用gRPC协议访问服务器

下面的代码实现了gRPC客户端 与 TensorFlow Serving 服务器交互。客户端对图片进行预处理后,向服务器发送请求,服务器完成花卉识别后,向客户端返回结果。以下是对关键代码的解释:

(1)图像预处理函数

def process_image(image: np.ndarray) -> np.ndarray:
    image_tensor = tf.convert_to_tensor(image)
    image_resized = tf.image.resize(image_tensor, (224, 224))
    image_resized /= 255

    return image_resized.numpy()
  • 将输入的图像数组转换为 TensorFlow 张量。
  • 调整图像大小为 (224, 224)
  • 将图像归一化到 [0, 1] 范围。
  • 返回预处理后的图像数组。

(2)加载和预处理图像的函数

def load_image(image_path):
    im = Image.open(image_path)
    image_arr = np.asarray(im)
    processed_image = process_image(image_arr)
    processed_image = np.expand_dims(processed_image, 0)
        
    return processed_image
  • 加载图像文件并转换为 NumPy 数组。
  • 调用 process_image 进行预处理。
  • 将图像扩展为 (1, 224, 224, 3) 形状,适应批处理输入。
  • 返回预处理后的图像。

(3)加载标签映射的函数

def load_label_map(label_map_path):
    with open(label_map_path, 'r', encoding='utf-8') as f:
        label_map = json.load(f)
    return label_map
  • 从 JSON 文件中加载标签映射。
  • 返回标签映射的字典。

(4)创建 gRPC 频道和存根

channel = grpc.insecure_channel('your server:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
  • 创建一个 gRPC 频道,连接到 TensorFlow Serving 服务。
  • 创建一个存根,用于与 TensorFlow Serving 进行通信。

(5)创建预测请求

request = predict_pb2.PredictRequest()
request.model_spec.name = 'ai_flower'
request.model_spec.signature_name = 'serving_default'
  • 创建一个 PredictRequest 对象。
  • 设置模型名称 ai_flower 和签名名称 serving_default

(6)读取和预处理图像

image_path = 'test_images/image_00250.jpg'
input_image = load_image(image_path)
  • 设置图像路径。
  • 调用 load_image 函数读取和预处理图像。

(7)设置请求输入张量

request.inputs['keras_layer_input'].CopyFrom(
    tf.make_tensor_proto(input_image, shape=input_image.shape))
  • 将预处理后的图像设置为请求的输入张量。

(8)发送请求并获取响应

response = stub.Predict(request)
  • 发送预测请求并获取响应。

(9)提取预测结果

output_tensor_name = 'dense'  # 修改为实际的键名
if output_tensor_name in response.outputs:
    predictions = tf.make_ndarray(response.outputs[output_tensor_name])
else:
    print(f"Output tensor '{output_tensor_name}' not found in the response.")
    predictions = []
  • 假设输出张量的键名是 dense,从响应中提取预测结果。
  • 如果键名不同,请根据实际情况进行修改。

2. 使用REST API协议访问服务器

与上述使用gRPC协议访问服务器实现的功能一样。以下只对有区别代码的进行解释:

(1)服务器 URL

server_url = 'http://your_server:8501/v1/models/ai_flower:predict'
  • 指定 TensorFlow Serving 服务器的 URL,发送预测请求到 ai_flower 模型的 predict 端点。

(2)发送 POST 请求到服务器

response = requests.post(server_url, json=data)
  • 通过 POST 请求将图像数据发送到 TensorFlow Serving 服务器。

(3)检查响应状态

if response.status_code == 200:
    result = response.json()
    predictions = np.array(result['predictions'])
    label_map = load_label_map('label_map.json')

    top_k = 5
    top_indices = np.argsort(predictions[0])[-top_k:][::-1]
    for i in top_indices:
        label_id = i + 1
        label_name = label_map.get(str(label_id), 'Unknown')
        confidence = predictions[0][i]
        print(f"label_id: {label_id}, Label: {label_name}, Confidence: {confidence:.4f}")
else:
    print(f"Request failed with status code {response.status_code}")
    print("Response:", response.text)
  • 检查响应状态码是否为 200(即请求成功)。
  • 解析响应 JSON 数据,提取预测结果。
  • 加载标签映射文件。
  • 获取前 5 位预测结果,打印每个预测类别的标签和信心分数。
  • 如果请求失败,打印状态码和响应内容。

完整源代码

  • 21
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值