我使用Tensorflow
Java Api将已创建的Tensorflow模型加载到JVM中.
这是我的简单scala代码:
import java.nio.file.{Files, Path, Paths}
import org.tensorflow.{Graph, Session, Tensor}
def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path)
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"))
val g = new Graph()
g.importGraphDef(graphDef)
val session = new Session(g)
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0))
如何保存模型以使Session和Graph存储在同一文件中.如上面的“PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb”中所述.
描述here它提到:
The serialized representation of the graph, often referred to as a
GraphDef, can be generated by toGraphDef() and equivalents in other
language APIs.
其他语言API的等价物是什么?我觉得很明显
注意:我已经在tensorflow_serving下查看了mnist_saved_model.py,但通过该过程保存它会给我一个.pb文件和一个变量文件夹.当我尝试加载.pb文件时,我得到:java.lang.IllegalArgumentException:无效的GraphDef
最佳答案 目前使用tensorflow的Java API,我只发现了如何将图形保存为graphDef(即没有其变量和元数据).这可以通过将Array [Byte]写入文件来完成:
Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef)
这里myGraph是Graph class的java对象.
我建议使用此处定义的SavedModel api从Python API保存模型.它会将模型保存在一个文件夹中,该文件夹包含.pb文件中的序列化图形和文件夹中的变量.请注意您在scala / java代码中使用的tag_constants,以便使用变量加载模型.然后使用java api中的SavedModelBundle java类轻松加载带变量的图形和会话.它返回一个包含图形和包含变量值的会话的包装器:
val model = SavedModelBundle.load(modelDir, modelTag)
如果你已经尝试过这个,也许你可以分享你的代码,看看为什么它返回了一个无效的GraphDef.
另一个选项是冻结图形,即您将变量节点变为常量节点,因此.pb文件中的所有内容都是自包含的. Mores infos here为冷冻部分