文章目录
对于tensorflow serving部署模型,网上最常见的是基于python的接口教程。针对非python的情况,例如java,单张图片的inference教程也很多,然而一次处理多张图片的教程则较少。这里结合网上的资料,以及自己的尝试,着重介绍java下多张图片的处理方式。
一、单张图片预测
整个流程分为两部分,客户端传入图片数据,tensorflow serving端解析数据,详细代码比较多,这里只介绍最核心的数据处理部分。
1、客户端
import org.tensorflow.framework.TensorProto;
TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.addStringVal(ByteString.copyFrom(imgBytes));
tensorProtoBuilder.setDtype(DataType.DT_STRING);
客户端是java代码,核心代码如上,主要是构造这个TensorProto的数据结构,并把图片数据放进去。其中的imgBytes是图片数据对应的byte数组,所以最终数据结构设置成了DT_STRING,后面会看到它也支持其它数据格式。其实这类proto格式的数据在各种语言下都可以使用,毕竟谷歌的protobuf是一种通用的数据格式。
2、服务端
import tensorflow as tf
image_raw = tf.placeholder(tf.string)
x = tf.image.decode_jpeg(image_raw)
# image pre-process
output = network_fn(x)
服务端则是tensorflow serving的一些具体代码了,可以在网上搜到相关教程。这里介绍如何处理客户端传来的图片流数据,这里很简单,直接decode即可。
以上就是单张图片的处理了,可以看到是非常简单的。然而为了提升效率,线上部署往往需要送入一个batch的图片,这个时候该怎么写,后面废了我好大劲才搞清楚。
二、多张图片预测(batching)
关于多张图片的预测,可以分为客户端batching和服务端batching,可以简单看看这里。针对服务端的batching,在python下只需设置一个参数即可,但在不同图片尺寸情况下,这种batching会自动把所有图片padding到统一尺寸,这不是我想要的方式。另一种则是客户端batching,操作起来相对更灵活。
下面介绍两种客户端batching的方式,也是参考了这里的回答
1、解析图片数据,制作proto的方式
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import java.awt.image.BufferedImage;
import javax.imageio.ImageIO;
import java.io.ByteArrayInputStream;
import org.tensorflow.framework.TensorShapeProto;
int imageHeight = 224;
int imageWidth = 224;
int num_images = 32;
// 注意需要将图片resize到统一尺寸,这里就不细写了
TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();
int[][][][] featuresTensorData = new int[num_images][imageHeight][imageWidth][3];
for(int i=0 ;i<num_images; i++) {
final byte[] imgBytes = (byte[]) requestData.get(i);
try {
BufferedImage image = ImageIO.read(new ByteArrayInputStream(imgBytes));
int[][] imageArray = new int[imageHeight][imageWidth];
for (int row = 0; row < imageHeight; row++) {
for (int column = 0; column < imageWidth; column++) {
imageArray[row][column] = image.getRGB(column, row);
int pixel = image.getRGB(column, row);
featuresTensorData[i][row][column][0] = (pixel >> 16) & 0xff; // red
featuresTensorData[i][row][column][1] = (pixel >> 8) & 0xff; // green
featuresTensorData[i][row][column][2] = pixel & 0xff; // blue
}
}
} catch (IOException e) {
e.printStackTrace();
System.exit(1);
}
}
for (int i = 0; i < featuresTensorData.length; ++i) {
for (int j = 0; j < featuresTensorData[i].length; ++j) {
for (int k = 0; k < featuresTensorData[i][j].length; ++k) {
for (int l = 0; l < featuresTensorData[i][j][k].length; ++l) {
featuresTensorBuilder.addFloatVal(featuresTensorData[i][j][k][l]);
}
}
}
}
TensorShapeProto.Dim dim1 = TensorShapeProto.Dim.newBuilder().setSize(num_images).build();
TensorShapeProto.Dim dim2 = TensorShapeProto.Dim.newBuilder().setSize(imageHeight).build();
TensorShapeProto.Dim dim3 = TensorShapeProto.Dim.newBuilder().setSize(imageWidth).build();
TensorShapeProto.Dim dim4 = TensorShapeProto.Dim.newBuilder().setSize(3).build();
TensorShapeProto featuresShape = TensorShapeProto.newBuilder().addDim(dim1).addDim(dim2).addDim(dim3).addDim(dim4).build();
featuresTensorBuilder.setTensorShape(featuresShape);
featuresTensorBuilder.setDtype(DataType.DT_FLOAT); // Notice that DT_INT may cause error(even though you set the placehold's datatype to int int tf_serving.py)
以上代码参考这里进行改写,不同的是该文中传入的是图片名,本文处理的是byte流。思路很清晰,就是解析图片数据,然后写入featuresTensorBuilder中进行传输。
这里需要注意的一点是,最后一行setDtype不能设置成DT_UINT8,我一开始为了让传输效率更高,考虑都是图片数据,直接设置成了uint8格式,结果最后老是调不通,猜测是服务端的支持问题。
服务端的代码如下
import tensorflow as tf
_IMAGE_HEIGHT = 224
_IMAGE_WIDTH = 224
image_raw = tf.placeholder(tf.float32, shape=(None, _IMAGE_HEIGHT, _IMAGE_WIDTH, 3))
x = tf.map_fn(lambda im: image_preprocessing_fn(im, _IMAGE_HEIGHT, _IMAGE_WIDTH), image_raw, dtype=tf.uint8)
output = network_fn(x)
可以看到,由于客户端做了大量处理(图片parse,图片resize等等),服务端的代码更简单了。但这种情况数据传输效率非常低下,因为给客户端造成了很大压力。但它也有好处,就是在客户端可以解析各种格式的图片,比如jpg、png,服务端只管处理图片数据即可。
下面介绍另一种相对高效的方式。
2、string数组的方式
int num_images = 32
for(int i=0; i<num_images; i++){
final byte[] imgBytes = (byte[]) requestData.get(i);
tensorProtoBuilder.addStringVal(ByteString.copyFrom(imgBytes));
}
TensorShapeProto.Dim featuresDim1 = TensorShapeProto.Dim.newBuilder().setSize(num_images).build();
TensorShapeProto featuresShape = TensorShapeProto.newBuilder().addDim(featuresDim1).build();
tensorProtoBuilder.setDtype(DataType.DT_STRING);
tensorProtoBuilder.setTensorShape(featuresShape);
其实这一段代码网上很容易就能找到,但之后的服务端处理代码就找不到了,我也是想了很久才写出来的(真的很久啊,查了好多资料),如下。
import tensorflow as tf
image_raw = tf.placeholder(tf.string, shape=(None, None))
#x1 = tf.map_fn(lambda im: tf.image.decode_jpeg(tf.reshape(im, [])), image_raw, dtype=tf.uint8)
#x = tf.map_fn(lambda image: image_preprocessing_fn(image, _IMAGE_HEIGHT, _IMAGE_WIDTH), x1, dtype=tf.float32)
x = tf.map_fn(lambda im: image_preprocessing_fn(tf.image.decode_jpeg(tf.reshape(im, [])), _IMAGE_HEIGHT, _IMAGE_WIDTH), image_raw, dtype=tf.float32)
output = network_fn(x)
如代码所示,把多个图片看成string数组,image_raw的shape[0]表示图片数量,shape1表示byte流内容,这样就能读取不同大小的图片了,然后统一放到服务端来resize和预处理。我感觉最有效的方式,往往代码都比较简洁。
PS
文章标题实在不知道写啥好,因为我也不知道该怎么形容,暂时先这样吧。
喜闻乐见的ps时间,本篇博客距离上一篇已经过去快一年时间了,不知道自己都在干嘛,但懒惰是毋庸置疑的。。
不过不管怎样,以后有值得总结记录的东西,还是会写在博客或者github里面的。