tensorfow serving batching inference with java interface

文章目录


对于tensorflow serving部署模型,网上最常见的是基于python的接口教程。针对非python的情况,例如java,单张图片的inference教程也很多,然而一次处理多张图片的教程则较少。这里结合网上的资料,以及自己的尝试,着重介绍java下多张图片的处理方式。

一、单张图片预测

整个流程分为两部分,客户端传入图片数据,tensorflow serving端解析数据,详细代码比较多,这里只介绍最核心的数据处理部分。

1、客户端
import org.tensorflow.framework.TensorProto;
TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.addStringVal(ByteString.copyFrom(imgBytes));
tensorProtoBuilder.setDtype(DataType.DT_STRING);

客户端是java代码,核心代码如上,主要是构造这个TensorProto的数据结构,并把图片数据放进去。其中的imgBytes是图片数据对应的byte数组,所以最终数据结构设置成了DT_STRING,后面会看到它也支持其它数据格式。其实这类proto格式的数据在各种语言下都可以使用,毕竟谷歌的protobuf是一种通用的数据格式。

2、服务端
import tensorflow as tf
image_raw = tf.placeholder(tf.string)
x = tf.image.decode_jpeg(image_raw)
# image pre-process
output = network_fn(x)

服务端则是tensorflow serving的一些具体代码了,可以在网上搜到相关教程。这里介绍如何处理客户端传来的图片流数据,这里很简单,直接decode即可。

以上就是单张图片的处理了,可以看到是非常简单的。然而为了提升效率,线上部署往往需要送入一个batch的图片,这个时候该怎么写,后面废了我好大劲才搞清楚。


二、多张图片预测(batching)

关于多张图片的预测,可以分为客户端batching和服务端batching,可以简单看看这里。针对服务端的batching,在python下只需设置一个参数即可,但在不同图片尺寸情况下,这种batching会自动把所有图片padding到统一尺寸,这不是我想要的方式。另一种则是客户端batching,操作起来相对更灵活。

下面介绍两种客户端batching的方式,也是参考了这里的回答

1、解析图片数据,制作proto的方式
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import java.awt.image.BufferedImage;
import javax.imageio.ImageIO;
import java.io.ByteArrayInputStream;
import org.tensorflow.framework.TensorShapeProto;

int imageHeight = 224;
int imageWidth = 224;
int num_images = 32;
// 注意需要将图片resize到统一尺寸,这里就不细写了
TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();
int[][][][] featuresTensorData = new int[num_images][imageHeight][imageWidth][3];
for(int i=0 ;i<num_images; i++) {
    final byte[] imgBytes = (byte[]) requestData.get(i);

    try {
        BufferedImage image = ImageIO.read(new ByteArrayInputStream(imgBytes));

        int[][] imageArray = new int[imageHeight][imageWidth];
        for (int row = 0; row < imageHeight; row++) {
            for (int column = 0; column < imageWidth; column++) {
                imageArray[row][column] = image.getRGB(column, row);

                int pixel = image.getRGB(column, row);

                featuresTensorData[i][row][column][0] = (pixel >> 16) & 0xff;   // red
                featuresTensorData[i][row][column][1] = (pixel >> 8) & 0xff;    // green
                featuresTensorData[i][row][column][2] = pixel & 0xff;           // blue
            }
        }
    } catch (IOException e) {
        e.printStackTrace();
        System.exit(1);
    }

}

for (int i = 0; i < featuresTensorData.length; ++i) {
    for (int j = 0; j < featuresTensorData[i].length; ++j) {
        for (int k = 0; k < featuresTensorData[i][j].length; ++k) {
            for (int l = 0; l < featuresTensorData[i][j][k].length; ++l) {
                featuresTensorBuilder.addFloatVal(featuresTensorData[i][j][k][l]);
            }
        }
    }
}


TensorShapeProto.Dim dim1 = TensorShapeProto.Dim.newBuilder().setSize(num_images).build();
TensorShapeProto.Dim dim2 = TensorShapeProto.Dim.newBuilder().setSize(imageHeight).build();
TensorShapeProto.Dim dim3 = TensorShapeProto.Dim.newBuilder().setSize(imageWidth).build();
TensorShapeProto.Dim dim4 = TensorShapeProto.Dim.newBuilder().setSize(3).build();


TensorShapeProto featuresShape = TensorShapeProto.newBuilder().addDim(dim1).addDim(dim2).addDim(dim3).addDim(dim4).build();
featuresTensorBuilder.setTensorShape(featuresShape);
featuresTensorBuilder.setDtype(DataType.DT_FLOAT);      // Notice that DT_INT may cause error(even though you set the placehold's datatype to int int tf_serving.py)

以上代码参考这里进行改写,不同的是该文中传入的是图片名,本文处理的是byte流。思路很清晰,就是解析图片数据,然后写入featuresTensorBuilder中进行传输。

这里需要注意的一点是,最后一行setDtype不能设置成DT_UINT8,我一开始为了让传输效率更高,考虑都是图片数据,直接设置成了uint8格式,结果最后老是调不通,猜测是服务端的支持问题。

服务端的代码如下

import tensorflow as tf

_IMAGE_HEIGHT = 224
_IMAGE_WIDTH = 224
image_raw = tf.placeholder(tf.float32, shape=(None, _IMAGE_HEIGHT, _IMAGE_WIDTH, 3))
x = tf.map_fn(lambda im: image_preprocessing_fn(im, _IMAGE_HEIGHT, _IMAGE_WIDTH), image_raw, dtype=tf.uint8)
output = network_fn(x)

可以看到,由于客户端做了大量处理(图片parse,图片resize等等),服务端的代码更简单了。但这种情况数据传输效率非常低下,因为给客户端造成了很大压力。但它也有好处,就是在客户端可以解析各种格式的图片,比如jpg、png,服务端只管处理图片数据即可。

下面介绍另一种相对高效的方式。

2、string数组的方式
int num_images = 32

for(int i=0; i<num_images; i++){
    final byte[] imgBytes = (byte[]) requestData.get(i);
    tensorProtoBuilder.addStringVal(ByteString.copyFrom(imgBytes));
}

TensorShapeProto.Dim featuresDim1 = TensorShapeProto.Dim.newBuilder().setSize(num_images).build();
TensorShapeProto featuresShape = TensorShapeProto.newBuilder().addDim(featuresDim1).build();

tensorProtoBuilder.setDtype(DataType.DT_STRING);
tensorProtoBuilder.setTensorShape(featuresShape);

其实这一段代码网上很容易就能找到,但之后的服务端处理代码就找不到了,我也是想了很久才写出来的(真的很久啊,查了好多资料),如下。

import tensorflow as tf

image_raw = tf.placeholder(tf.string, shape=(None, None))
#x1 = tf.map_fn(lambda im: tf.image.decode_jpeg(tf.reshape(im, [])), image_raw, dtype=tf.uint8)
#x = tf.map_fn(lambda image: image_preprocessing_fn(image, _IMAGE_HEIGHT, _IMAGE_WIDTH), x1, dtype=tf.float32)
x = tf.map_fn(lambda im: image_preprocessing_fn(tf.image.decode_jpeg(tf.reshape(im, [])), _IMAGE_HEIGHT, _IMAGE_WIDTH), image_raw, dtype=tf.float32)
output = network_fn(x)

如代码所示,把多个图片看成string数组,image_raw的shape[0]表示图片数量,shape1表示byte流内容,这样就能读取不同大小的图片了,然后统一放到服务端来resize和预处理。我感觉最有效的方式,往往代码都比较简洁。


PS

文章标题实在不知道写啥好,因为我也不知道该怎么形容,暂时先这样吧。

喜闻乐见的ps时间,本篇博客距离上一篇已经过去快一年时间了,不知道自己都在干嘛,但懒惰是毋庸置疑的。。

不过不管怎样,以后有值得总结记录的东西,还是会写在博客或者github里面的。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值