A/native: tensor.cc:487 Check failed: dtype() == expected_dtype (9 vs. 3)
原因解释
出现这种错误的原因是:程序预期的数据是tf.int32
,然而真正的数据是tf.int64
。在我的应用中,出现以上错误的程序如下。
在TF模型中,有如下的定义。
self.predictions = tf.argmax(self.logits, 1, name="predictions")
这里,函数tf.argmax
返回的Tensor的类型是tf.int64
。
在进行了freeze_graph
和optimize_for_inference
之后,我在Android app中按照如下方式调用模型进行了预测:
tensorFlowInferenceInterface = new TensorFlowInferenceInterface();
tensorFlowInferenceInterface.initializeTensorFlow(assetManager,MODEL_FILE);