需要用到 TensorFlow 官方提供的 Java API,Maven 依赖如下:
<!-- Maven dependencies for running a TensorFlow model from Java; all three artifacts are pinned to the same version (1.3.0). -->
<dependencies>
<!-- TensorFlow Java API; presumably the artifact providing the org.tensorflow classes imported by the test code. -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow</artifactId>
<version>1.3.0</version>
</dependency>
<!-- NOTE(review): judging by its name this provides the TensorFlow protobuf classes; the test code in this document does not appear to reference them directly - confirm whether it is required. -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>proto</artifactId>
<version>1.3.0</version>
</dependency>
<!-- Presumably the native (JNI) library backing the Java API, judging by the artifact name; keep its version in sync with libtensorflow. -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni</artifactId>
<version>1.3.0</version>
</dependency>
</dependencies>
这里使用的是以 SavedModel 方式保存的 pb 格式模型(加载时需要传入模型所在目录和保存时指定的 tag),该模型文件的生成请参考:模型文件保存
测试代码:
import org.tensorflow.*;
import java.util.List;
/**
 * Loads a TensorFlow SavedModel and runs a single prediction on one hard-coded
 * MNIST test image, printing every value of the fetched output tensor.
 *
 * <p>Created by ling913.
 */
public class Test {

    /** Directory of the SavedModel to load. */
    private static final String MODEL_DIR = "./src/main/resources/model2";

    /** Tag the model is loaded with. */
    private static final String MODEL_TAG = "mytag";

    /** Name of the graph operation that is fetched. */
    private static final String PREDICT_OP = "predict";

    /** Name the input tensor is fed under. */
    private static final String INPUT_NAME = "input_x";

    /** Length of the second dimension of the output buffer the result is copied into. */
    private static final int NUM_CLASSES = 10;

    /**
     * Runs the prediction and prints one output value per line.
     *
     * @param args unused
     * @throws IllegalStateException if the loaded graph has no operation named {@code predict}
     */
    public static void main(String[] args) {
        // A batch containing a single flattened image: shape [1][784].
        float[][] batch = {sampleImage()};

        // The bundle and the tensors hold native memory, so every one of them must be
        // closed explicitly. Closing the bundle also releases its session and graph.
        try (SavedModelBundle bundle = SavedModelBundle.load(MODEL_DIR, MODEL_TAG);
             Tensor inputX = Tensor.create(batch)) {
            // The op to execute; Graph.operation returns null when the name is unknown.
            Operation predictOp = bundle.graph().operation(PREDICT_OP);
            if (predictOp == null) {
                throw new IllegalStateException(
                        "No operation named '" + PREDICT_OP + "' in the model loaded from " + MODEL_DIR);
            }
            Output output = new Output(predictOp, 0);

            List<Tensor> results = bundle.session().runner()
                    .feed(INPUT_NAME, inputX)
                    .fetch(output)
                    .run();
            try {
                for (Tensor result : results) {
                    float[][] scores = new float[1][NUM_CLASSES];
                    result.copyTo(scores);
                    for (float score : scores[0]) {
                        System.out.println(score);
                    }
                }
            } finally {
                // Tensors returned by run() are owned by the caller.
                for (Tensor result : results) {
                    result.close();
                }
            }
        }
    }

    /**
     * Returns the test input: MNIST test sample {@code mnist.test.images[15]}, whose label
     * is the digit 5, as 784 floats.
     *
     * <p>The values are laid out 28 per line (MNIST images are 28x28) so the shape of the
     * digit can be seen in the source.
     */
    private static float[] sampleImage() {
        return new float[] {
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0.2f,0.517647f, 0.839216f,0.992157f,0.996078f,0.992157f,0.796079f,0.635294f,0.160784f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0.4f,0.556863f,0.796079f,0.796079f,0.992157f,0.988235f, 0.992157f,0.988235f,0.592157f,0.27451f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0.996078f,0.992157f,0.956863f,0.796079f,0.556863f,0.4f, 0.321569f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0.67451f,0.988235f,0.796079f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0.0823529f,0.87451f,0.917647f,0.117647f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0.478431f,0.992157f,0.196078f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0.482353f,0.996078f,0.356863f,0.2f,0.2f,0.2f, 0.0392157f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0.0823529f,0.87451f,0.992157f,0.988235f,0.992157f,0.988235f,0.992157f, 0.67451f,0.321569f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0.0823529f, 0.839216f,0.992157f,0.796079f,0.635294f,0.4f,0.4f,0.796079f, 0.87451f,0.996078f,0.992157f,0.2f,0.0392157f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0.239216f, 0.992157f,0.670588f,0f,0f,0f,0f,0f, 0.0784314f,0.439216f,0.752941f,0.992157f,0.831373f,0.160784f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0.4f,0.796079f,0.917647f,0.2f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0.0784314f,0.835294f,0.909804f, 0.321569f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0.243137f,0.796079f, 0.917647f,0.439216f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0.0784314f, 0.835294f,0.988235f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0.6f,0.992157f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0.160784f, 0.913726f,0.831373f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0.443137f,0.360784f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0.121569f,0.678431f,0.956863f,0.156863f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0.321569f,0.992157f,0.592157f,0f,0f,0f, 0f,0f,0f,0.0823529f,0.4f,0.4f,0.717647f, 0.913726f,0.831373f,0.317647f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0.321569f,1.0f,0.992157f,0.917647f,0.596078f,0.6f, 0.756863f,0.678431f,0.992157f,0.996078f,0.992157f,0.996078f,0.835294f, 0.556863f,0.0784314f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0.278431f,0.592157f,0.592157f,0.909804f,0.992157f, 0.831373f,0.752941f,0.592157f,0.513726f,0.196078f,0.196078f,0.0392157f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f,
            0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0f, 0f,0f,0f,0f,0f,0f,0.0f
        };
    }
}
输出:
2.5921327E-4
1.7069127E-5
1.0343825E-4
0.015574839
2.9570143E-5
0.9706799
7.140153E-6
7.191204E-5
0.013250019
6.8286636E-6
第6个值(下标为5)最大,约为0.97,说明模型预测的结果是数字5,与该样本的标签一致