TensorFlow for Java
WARNING: The TensorFlow Java API is not currently covered by the TensorFlow API stability guarantees.
警告：目前，TensorFlow Java API 不在 TensorFlow API 稳定性保证的范围内。
To use TensorFlow on Android, refer to TensorFlow Lite instead.
如需在 Android 上使用 TensorFlow，请改用 TensorFlow Lite。
使用 Java API 进行预测
/**
 * Loads the serialized TensorFlow model {@code car_model.pb} and opens a session on it.
 *
 * <p>NOTE(review): in TensorFlow Java 1.x closing a {@link Session} does not close the
 * {@link Graph} it was built from. Callers never receive the graph here, so its native memory
 * lives until the JVM exits. Consider exposing it if sessions are created repeatedly.
 *
 * @return a session bound to the freshly imported graph
 * @throws IllegalStateException if the model file cannot be found or read
 */
private static Session loadSession() {
    byte[] graphBytes;
    // Read the serialized GraphDef into memory; the stream is closed even if reading fails.
    try (InputStream is = getStreamFromPb("car_model.pb")) {
        if (is == null) {
            throw new IllegalStateException("Model file not found: car_model.pb");
        }
        graphBytes = IOUtils.toByteArray(is);
    } catch (java.io.IOException e) {
        // Swallowing this used to import an empty graph and defer the failure to the first
        // session.run() ("no operation named ..."); fail fast with the real cause instead.
        throw new IllegalStateException("Failed to read model file: car_model.pb", e);
    }

    Graph graph = new Graph(); // create an empty graph
    try {
        graph.importGraphDef(graphBytes); // in-memory bytes -> graph structure
        return new Session(graph); // initialize the session from the graph
    } catch (RuntimeException e) {
        // Release the graph's native resources if the import or session creation fails.
        graph.close();
        throw e;
    }
}
/**
 * Computes the face embedding of one image and returns it as JSON keyed by the image path.
 *
 * <p>The image is read and whitened, then fed to the graph's {@code "input"} operation as a
 * batch of one, together with a scalar {@code false} for {@code "phase_train"}. The
 * {@code "embeddings"} output is copied into a 128-element vector.
 *
 * <p>NOTE(review): this assumes {@code whiten(...)} yields a 160x160x3 array and the model
 * outputs a [1, 128] float tensor. Confirm both against the model.
 *
 * @param session session created by {@code loadSession()}
 * @param imagePath path of the image to embed; also used as the JSON key
 * @return a JSON object string mapping {@code imagePath} to the embedding vector. On any failure
 *     the stack trace is printed and the vector is all zeros (best-effort contract kept from the
 *     original).
 */
private static String faceEmbedding(Session session, String imagePath) {
    float[][] embeddingsRes = new float[1][128];
    try {
        float[][][] rgbImage = readImage(imagePath);
        float[][][] rgbWhitened = whiten(rgbImage);
        // Wrap the single image into a batch of size 1.
        float[][][][] rgbFloat = new float[][][][] {rgbWhitened};
        // Tensors hold native memory, so every one of them (inputs and fetched output) is closed.
        try (Tensor<Float> imageTensor = Tensors.create(rgbFloat); // input tensor
                Tensor<Boolean> phaseTensor = Tensors.create(false); // input tensor
                Tensor<?> embeddings = session.runner()
                        .feed("input", imageTensor)
                        .feed("phase_train", phaseTensor)
                        .fetch("embeddings")
                        .run()
                        .get(0)) { // run the graph and take the output tensor
            System.out.println("embeddings.toString(): " + embeddings.toString());
            embeddings.copyTo(embeddingsRes);
        }
    } catch (Exception e) {
        // NOTE(review): callers receive a zero vector on failure, which is indistinguishable
        // from a real embedding. Consider propagating the error instead.
        e.printStackTrace();
    }
    JSONObject json = new JSONObject();
    // Was `image_path`, an undefined identifier; the key is the method parameter.
    json.put(imagePath, embeddingsRes[0]);
    return json.toString();
}
参考:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java
https://www.jianshu.com/p/e11891418bc1