Java importGraphDef()函数仅导入计算图形(由Python代码中的tf.train.write_graph编写),它不加载训练变量的值(存储在检查点中),这就是为什么你会抱怨未初始化变量的错误。
另一方面,TensorFlow SavedModel format包含有关模型的所有信息(图形,检查点状态,其他元数据),并且在Java中使用,您希望使用SavedModelBundle.load创建使用训练变量值初始化的会话。
要从Python导出此格式的模型,您可能需要查看相关问题Deploy retrained inception SavedModel to google cloud ml engine
在您的情况下,这应该类似于Python中的以下内容:
def save_model(session, input_tensor, output_tensor):
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
)
b = saved_model_builder.SavedModelBuilder('/tmp/model')
b.add_meta_graph_and_variables(session,
[tf.saved_model.tag_constants.SERVING],
signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
b.save()并通过save_model(session, x, yhat)调用它
然后在Java中加载模型使用:
try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) {
// b.session().run(...)
}希望有所帮助。