Java调用Tensorflow训练好的模型做预测,首先需要读取词典,然后加载模型,读入数据,最后预测结果。
模型训练参考上一篇博客:使用Tensorflow训练LSTM+Attention中文标题党分类
首先需要下载一些包,如果是maven项目在pom.xml中添加两个依赖。
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.5.0</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni</artifactId>
<version>1.5.0</version>
</dependency>
读取词典文件
这个词典文件wordIndexMap.txt,就是上一篇对应训练模型之前生成的词典文件。每行一个词和词的编号。
// 从文件读取词典文件存入Map
private static Map<String, Integer> readVocabFromFile(String pathname) throws IOException{
Map<String, Integer> wordMap = new HashMap<String, Integer>();
File filename = new File(pathname);
InputStreamReader reader = new InputStreamReader(new FileInputStream(filename));
BufferedReader br = new BufferedReader(reader);
String line = "";
line = br.readLine();
String[] lineArray;
while(line != null){
lineArray = line.split(" "