Java importGraphDef()函数只导入计算图(由Python代码中的tf.train.write_graph编写),它不加载训练变量的值(存储在检查点中),这就是为什么你会得到一个错误抱怨未初始化的变量.
另一方面,TensorFlow SavedModel format包含有关模型(图形,检查点状态,其他元数据)的所有信息,并且在Java中使用,您希望使用SavedModelBundle.load来创建使用训练变量值初始化的会话.
在您的情况下,这应该类似于Python中的以下内容:
def save_model(session, input_tensor, output_tensor):
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
)
b = saved_model_builder.SavedModelBuilder('/tmp/model')
b.add_meta_graph_and_variables(session,
[tf.saved_model.tag_constants.SERVING],
signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
b.save()
并通过save_model调用它(session,x,yhat)
然后在Java中加载模型使用:
try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) {
// b.session().run(...)
}
希望有所帮助.