tensorflow的python离线训练java在线预测方案

tensorflow目前主要的使用语言主要还是python,但是有相当一部分互联网应用是用java开发的,那么java应用如何使用tensorflow开发深度学习相关的功能呢?虽然google开源了tensorflow serving用于生产环境部署训练好的模型,但需要自己实现集群功能和健康检查,同时和java应用中间还隔着一个网络通讯的开销。所以最好还是java应用内部直接调用模型。tensorflow 1.1版本已经推出了java接口,不过我看了一下目前接口数量还是比较少,跟python丰富的各类接口没法比。因此完全使用java接口来构建模型不太现实,而且我估计模型训练效率可能也没python好。另一方面,网上开源的tensorflow模型基本都是用python的,用java重新构建费时费力。基于上述原因,python构建并训练模型+java在线预测是比较合理的方案。

在python训练代码里,模型训练好以后,要用tf.train.write_graph把整个图的protobuf写到文件中,但是tf.train.write_graph只能保存图的定义和constant参数,variable会被忽略掉,可以使用tf.graph_util.convert_variables_to_constants把variable转成constant再写到文件中,这样学习到的参数就不会丢失。相关代码:

[python]  view plain  copy
  1. graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output/logits"])  
  2. tf.train.write_graph(graph, '.''graph.pb', as_text=False)  

java加载模型进行预测,需要使用jdk8,如果是maven项目的话需要添加下面的依赖:

[html]  view plain  copy
  1. <dependency>  
  2.   <groupId>org.tensorflow</groupId>  
  3.   <artifactId>tensorflow</artifactId>  
  4.   <version>1.1.0</version>  
  5. </dependency>  

可以参考官方例子https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java。我的程序中的关键代码如下:

[java]  view plain  copy
  1. String modelDir = ".";  
  2. byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "graph.pb"));  
  3. Graph g = new Graph();  
  4. g.importGraphDef(graphDef);  
  5. Session s = new Session(g);  
  6. Tensor input = constructTensor(data);  
  7. Tensor result = s.runner().feed("input", input).fetch("output/logits").run().get(0);  
  8. long[] rshape = result.shape();  
  9. int nlabels = (int) rshape[1];  
  10. int batchSize = (int) rshape[0];  
  11. float[][] logits = result.copyTo(new float[batchSize][nlabels]);  

其中constructTensor是自己实现的函数,负责把待检测数据转化成一个Tensor,最后的logits数组是模型的预测值。注意Graph和Session都是线程安全的,只需要单例使用即可。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值