tensorflow目前主要的使用语言主要还是python,但是有相当一部分互联网应用是用java开发的,那么java应用如何使用tensorflow开发深度学习相关的功能呢?虽然google开源了tensorflow serving用于生产环境部署训练好的模型,但需要自己实现集群功能和健康检查,同时和java应用中间还隔着一个网络通讯的开销。所以最好还是java应用内部直接调用模型。tensorflow 1.1版本已经推出了java接口,不过我看了一下目前接口数量还是比较少,跟python丰富的各类接口没法比。因此完全使用java接口来构建模型不太现实,而且我估计模型训练效率可能也没python好。另一方面,网上开源的tensorflow模型基本都是用python的,用java重新构建费时费力。基于上述原因,python构建并训练模型+java在线预测是比较合理的方案。
java加载模型进行预测,需要使用jdk8,如果是maven项目的话需要添加下面的依赖:
可以参考官方例子https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java。我的程序中的关键代码如下:
其中constructTensor是自己实现的函数,负责把待检测数据转化成一个Tensor,最后的logits数组是模型的预测值。注意Graph和Session都是线程安全的,只需要单例使用即可。
在python训练代码里,模型训练好以后,要用tf.train.write_graph把整个图的protobuf写到文件中,但是tf.train.write_graph只能保存图的定义和constant参数,variable会被忽略掉,可以使用tf.graph_util.convert_variables_to_constants把variable转成constant再写到文件中,这样学习到的参数就不会丢失。相关代码:
java加载模型进行预测,需要使用jdk8,如果是maven项目的话需要添加下面的依赖:
可以参考官方例子https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java。我的程序中的关键代码如下:
其中constructTensor是自己实现的函数,负责把待检测数据转化成一个Tensor,最后的logits数组是模型的预测值。注意Graph和Session都是线程安全的,只需要单例使用即可。