tensorflow的python离线训练java在线预测方案

tensorflow目前主要的使用语言主要还是python,但是有相当一部分互联网应用是用java开发的,那么java应用如何使用tensorflow开发深度学习相关的功能呢?虽然google开源了tensorflow serving用于生产环境部署训练好的模型,但需要自己实现集群功能和健康检查,同时和java应用中间还隔着一个网络通讯的开销。所以最好还是java应用内部直接调用模型。tensorflow 1.1版本已经推出了java接口,不过我看了一下目前接口数量还是比较少,跟python丰富的各类接口没法比。因此完全使用java接口来构建模型不太现实,而且我估计模型训练效率可能也没python好。另一方面,网上开源的tensorflow模型基本都是用python的,用java重新构建费时费力。基于上述原因,python构建并训练模型+java在线预测是比较合理的方案。

在python训练代码里,模型训练好以后,要用tf.train.write_graph把整个图的protobuf写到文件中,但是tf.train.write_graph只能保存图的定义和constant参数,variable会被忽略掉,可以使用tf.graph_util.convert_variables_to_constants把variable转成constant再写到文件中,这样学习到的参数就不会丢失。相关代码:

graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output/logits"])
tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

java加载模型进行预测,需要使用jdk8,如果是maven项目的话需要添加下面的依赖:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow</artifactId>
  <version>1.1.0</version>
</dependency>

可以参考官方例子https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java。我的程序中的关键代码如下:

String modelDir = ".";
byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "graph.pb"));
Graph g = new Graph();
g.importGraphDef(graphDef);
Session s = new Session(g);
Tensor input = constructTensor(data);
Tensor result = s.runner().feed("input", input).fetch("output/logits").run().get(0);
long[] rshape = result.shape();
int nlabels = (int) rshape[1];
int batchSize = (int) rshape[0];
float[][] logits = result.copyTo(new float[batchSize][nlabels]);

其中constructTensor是自己实现的函数,负责把待检测数据转化成一个Tensor,最后的logits数组是模型的预测值。注意Graph和Session都是线程安全的,只需要单例使用即可。

没有更多推荐了,返回首页