环境准备:
ubuntu20.10 + python3 + tensorflow2.x
参考博客:《ubuntu20.10 安装低版本tensorflow1.8 或者 tensorflow2.5 步骤》(Navy 的博客,CSDN),详细描述了 tensorflow2.5 的环境搭建以及模型训练。
参考博客:《ubuntu20.10 tensorflow2.5 将训练后的模型移植到android 平台之官网demo 运行(二)》(Navy 的博客,CSDN),介绍了如何下载并调试完整的官方 demo。
1. 将目录 /tensorflow2.0/models/research/object_detection/training_tf2 下训练得到的模型导出为 SavedModel(Exports TF2 detection SavedModel for conversion to TensorFlow Lite)。
#tensorflow2.x
python3 export_tflite_graph_tf2.py --pipeline_config_path=training_tf2/ssd_mobilenet_v2_320x320_coco17_tpu-8.config --trained_checkpoint_dir=training_tf2/ --output_directory=training_tf2/train_export/TFlite
模型数据:
SavedModel 数据:
2. 将 SavedModel 数据通过 pb_to_tf2lite.py 文件转换成 navy_tflite.tflite 文件
pb_to_tf2lite.py :
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Convert the exported TF2 detection SavedModel into a .tflite flatbuffer."""
import tensorflow as tf

# SavedModel written by export_tflite_graph_tf2.py (step 1).
saved_model_dir = '/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/saved_model/'
# Destination of the converted model, next to the SavedModel export.
tflite_path = '/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/navy_tflite.tflite'

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# Explicitly opt in to the new converter (already the default in recent TF 2.x).
converter.experimental_new_converter = True
tflite_model = converter.convert()

# `with` guarantees the handle is flushed and closed even if the write fails;
# the original `open(...).write(...)` left closing to the garbage collector.
with open(tflite_path, 'wb') as f:
    f.write(tflite_model)
3. 将 navy_tflite.tflite 以及 navy_lable.txt 文件通过 metadata_writer_for_object_detection.py 文件转换成 navy_tflite_metadata.tflite 文件
navy_lable.txt:列出你所训练的全部物体类别,每行一个,例如:
cup
curtain
metadata_writer_for_object_detection.py :
"""Writes metadata and label file to the object_detection models."""
from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils

# Task Library metadata writer for object-detection models.
ObjectDetectorWriter = object_detector.MetadataWriter

# Float model produced by pb_to_tf2lite.py.
_MODEL_PATH = "/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/navy_tflite.tflite"
# Plain-text label file, one class name per line (the format the Task Library expects).
_LABEL_FILE = "/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/navy_lable.txt"
# Output: the same model with metadata and the label file packed in.
_SAVE_TO_PATH = "/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/navy_tflite_metadata.tflite"

# Input normalization recorded in the metadata, used when preprocessing images.
# It may be omitted only for uint8-quantized inputs that take raw [0, 255]
# pixels; this model's input is float32, so mean/std of 127.5 are supplied.
# See https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Build the writer from the raw model bytes plus normalization and labels.
model_buffer = writer_utils.load_file(_MODEL_PATH)
writer = ObjectDetectorWriter.create_for_inference(
    model_buffer,
    [_INPUT_NORM_MEAN],
    [_INPUT_NORM_STD],
    [_LABEL_FILE],
)

# Dump the generated metadata as JSON so it can be inspected before saving.
metadata_json = writer.get_metadata_json()
print(metadata_json)

# Serialize the model with metadata (and packed label file) to disk.
model_with_metadata = writer.populate()
writer_utils.save_file(model_with_metadata, _SAVE_TO_PATH)
下面着重讲一下为什么要产生navy_tflite_metadata.tflite 这个文件:
metadata 一般指元数据。
navy_tflite_metadata.tflite = navy_tflite.tflite + navy_lable.txt(Writes metadata and label file to the object_detection models)。简单来说,就是把 label 文件作为关联文件打包进 TFLite 模型的 metadata 字段中。
对于官方 demo 来说,如果 .tflite 文件中没有打包 label 文件,运行时将会报错:java.lang.IllegalStateException: This model does not contain associated files, and is not a Zip file.
详细信息请查询:https://tensorflow.google.cn/lite/convert/metadata_writer_tutorial?hl=en
补充: 对于图像分类(image_classification)使用:metadata_writer_for_image_classifier.py
路径:examples-master\lite\examples\image_classification\metadata
"""Writes metadata and label file to the image classifier models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
import tensorflow as tf
import flatbuffers
# pylint: disable=g-direct-tensorflow-import
from tflite_support import metadata_schema_py_generated as _metadata_fb
from tflite_support import metadata as _metadata
# pylint: enable=g-direct-tensorflow-import
FLAGS = flags.FLAGS


def define_flags():
    """Register the command-line flags used by main().

    Every flag is a string defaulting to None and is then marked as required,
    so absl validates that all of them were supplied when flags are parsed.
    """
    required_string_flags = (
        ("model_file", "Path and file name to the TFLite model file."),
        ("label_file", "Path to the label file."),
        ("export_directory",
         "Path to save the TFLite model files with metadata."),
    )
    # Define all flags first, then mark them required (same order as before).
    for flag_name, help_text in required_string_flags:
        flags.DEFINE_string(flag_name, None, help_text)
    for flag_name, _ in required_string_flags:
        flags.mark_flag_as_required(flag_name)
class ModelSpecificInfo(object):
    """Static facts about one image-classifier model, used to fill its metadata.

    Attributes mirror the constructor arguments one-to-one.
    """

    def __init__(self, name, version, image_width, image_height, image_min,
                 image_max, mean, std, num_classes, author):
        # Identity of the model as shown in its metadata.
        self.name, self.version, self.author = name, version, author
        # Expected input image size.
        self.image_width, self.image_height = image_width, image_height
        # Range of raw input pixel values.
        self.image_min, self.image_max = image_min, image_max
        # Per-channel normalization parameters (lists).
        self.mean, self.std = mean, std
        # Number of output categories.
        self.num_classes = num_classes
# Per-model settings, keyed by the .tflite file's base name; main() looks a
# model up by os.path.basename(--model_file).
_MODEL_INFO = {}
_MODEL_INFO["mobilenet_v1_0.75_160_quantized.tflite"] = ModelSpecificInfo(
    name="MobileNetV1 image classifier",
    version="v1",
    author="TensorFlow",
    image_width=160,
    image_height=160,
    image_min=0,
    image_max=255,
    mean=[127.5],
    std=[127.5],
    num_classes=1001,
)
class MetadataPopulatorForImageClassifier(object):
    """Creates image-classifier metadata and writes it into a TFLite model file.

    Args:
      model_file: path of the .tflite file to modify; it is rewritten in place.
      model_info: ModelSpecificInfo supplying name, version, author, input
        size/range, normalization mean/std and number of classes.
      label_file_path: path of the label file; it is packed into the model and
        referenced from the output tensor's metadata.
    """

    def __init__(self, model_file, model_info, label_file_path):
        self.model_file = model_file
        self.model_info = model_info
        self.label_file_path = label_file_path
        # Serialized ModelMetadata flatbuffer, set by _create_metadata().
        self.metadata_buf = None

    def populate(self):
        """Builds the metadata, then writes it and the label file into the model."""
        self._create_metadata()
        self._populate_metadata()

    def _create_metadata(self):
        """Builds the ModelMetadata flatbuffer and stores it in self.metadata_buf."""
        model_meta = self._create_model_meta()

        # One subgraph: a single image input tensor and a single output tensor.
        subgraph = _metadata_fb.SubGraphMetadataT()
        subgraph.inputTensorMetadata = [self._create_input_meta()]
        subgraph.outputTensorMetadata = [self._create_output_meta()]
        model_meta.subgraphMetadata = [subgraph]

        builder = flatbuffers.Builder(0)
        builder.Finish(
            model_meta.Pack(builder),
            _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
        self.metadata_buf = builder.Output()

    def _create_model_meta(self):
        """Returns model-level metadata: name, description, version, author, license."""
        info = self.model_info
        model_meta = _metadata_fb.ModelMetadataT()
        model_meta.name = info.name
        model_meta.description = ("Identify the most prominent object in the "
                                  "image from a set of %d categories." %
                                  info.num_classes)
        model_meta.version = info.version
        model_meta.author = info.author
        model_meta.license = ("Apache License. Version 2.0 "
                              "http://www.apache.org/licenses/LICENSE-2.0.")
        return model_meta

    def _create_input_meta(self):
        """Returns metadata for the image input: RGB, normalization and value range."""
        info = self.model_info
        input_meta = _metadata_fb.TensorMetadataT()
        input_meta.name = "image"
        # Channel order matches the RGB color space declared below.
        input_meta.description = (
            "Input image to be classified. The expected image is {0} x {1}, with "
            "three channels (red, green, and blue) per pixel. Each value in the "
            "tensor is a single byte between {2} and {3}.".format(
                info.image_width, info.image_height,
                info.image_min, info.image_max))
        input_meta.content = _metadata_fb.ContentT()
        input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
        input_meta.content.contentProperties.colorSpace = (
            _metadata_fb.ColorSpaceType.RGB)
        input_meta.content.contentPropertiesType = (
            _metadata_fb.ContentProperties.ImageProperties)

        # Normalization parameters (per-channel mean/std) recorded for consumers.
        input_normalization = _metadata_fb.ProcessUnitT()
        input_normalization.optionsType = (
            _metadata_fb.ProcessUnitOptions.NormalizationOptions)
        input_normalization.options = _metadata_fb.NormalizationOptionsT()
        input_normalization.options.mean = info.mean
        input_normalization.options.std = info.std
        input_meta.processUnits = [input_normalization]

        input_stats = _metadata_fb.StatsT()
        input_stats.max = [info.image_max]
        input_stats.min = [info.image_min]
        input_meta.stats = input_stats
        return input_meta

    def _create_output_meta(self):
        """Returns metadata for the probability output, linked to the label file."""
        output_meta = _metadata_fb.TensorMetadataT()
        output_meta.name = "probability"
        output_meta.description = (
            "Probabilities of the %d labels respectively." %
            self.model_info.num_classes)
        output_meta.content = _metadata_fb.ContentT()
        # The union value field is `contentProperties` (same as for the input
        # tensor); it must be set so it agrees with contentPropertiesType.
        output_meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()
        output_meta.content.contentPropertiesType = (
            _metadata_fb.ContentProperties.FeatureProperties)

        output_stats = _metadata_fb.StatsT()
        output_stats.max = [1.0]
        output_stats.min = [0.0]
        output_meta.stats = output_stats

        # Reference the label file by base name; it is packed in _populate_metadata.
        label_file = _metadata_fb.AssociatedFileT()
        label_file.name = os.path.basename(self.label_file_path)
        label_file.description = "Labels for objects that the model can recognize."
        label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
        output_meta.associatedFiles = [label_file]
        return output_meta

    def _populate_metadata(self):
        """Writes self.metadata_buf and the label file into self.model_file."""
        populator = _metadata.MetadataPopulator.with_model_file(self.model_file)
        populator.load_metadata_buffer(self.metadata_buf)
        populator.load_associated_files([self.label_file_path])
        populator.populate()
def main(_):
    """Copy --model_file into --export_directory and pack metadata into the copy.

    Also writes the resulting metadata as <model name>.json next to the copy
    and prints where everything was written.

    Raises:
      ValueError: if the model's base file name has no entry in _MODEL_INFO.
    """
    model_file = FLAGS.model_file
    model_basename = os.path.basename(model_file)
    if model_basename not in _MODEL_INFO:
        raise ValueError(
            "The model info for, {0}, is not defined yet.".format(model_basename))
    model_info = _MODEL_INFO[model_basename]

    export_model_path = os.path.join(FLAGS.export_directory, model_basename)
    # Work on a copy so the source model is never modified. Create the export
    # directory if needed; overwrite=False refuses to clobber an existing file.
    if FLAGS.export_directory:
        tf.io.gfile.makedirs(FLAGS.export_directory)
    tf.io.gfile.copy(model_file, export_model_path, overwrite=False)

    # Generate the metadata objects and write them into the copied model.
    populator = MetadataPopulatorForImageClassifier(
        export_model_path, model_info, FLAGS.label_file)
    populator.populate()

    # Validate the output by reading the metadata back and saving it as JSON.
    displayer = _metadata.MetadataDisplayer.with_model_file(export_model_path)
    export_json_file = os.path.join(FLAGS.export_directory,
                                    os.path.splitext(model_basename)[0] + ".json")
    json_file = displayer.get_metadata_json()
    with open(export_json_file, "w", encoding="utf-8") as f:
        f.write(json_file)

    # Report the exported model (the file that actually received the metadata).
    print("Finished populating metadata and associated file to the model:")
    print(export_model_path)
    print("The metadata json file has been saved to:")
    print(export_json_file)
    print("The associated file that has been packed to the model is:")
    print(displayer.get_packed_associated_file_list())


if __name__ == "__main__":
    # Flags must be registered before app.run parses the command line.
    define_flags()
    app.run(main)
4. 将 navy_tflite_metadata.tflite以及 navy_lable.txt 文件复制到官方demo 的 src\main\assets。将demo 中 DetectorActivity.java 以及 DetectorTest.java 中两处加载 lite-model_ssd_mobilenet_v1_1_metadata_2.tflite 以及 labelmap.txt 地方修改成 navy_tflite_metadata.tflite以及 navy_lable.txt 。
点击安装运行app,app 闪退。logcat 出现错误:java.lang.IllegalArgumentException: Cannot copy to a TensorFlowLite tensor (serving_default_input:0) with 1080000 bytes from a Java Buffer with 270000 bytes.
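原因:该模型的输入是 float32,每个通道占 4 字节,共需 1×300×300×3×4 = 1080000 字节;而 demo 原先按量化模型处理,每个通道只写入 1 字节,即 300×300×3 = 270000 字节。解决方法:将 demo 中 DetectorActivity.java 以及 DetectorTest.java 中的量化标志修改为 false: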
#DetectorActivity.java
private static final boolean TF_OD_API_IS_QUANTIZED = false;
#DetectorTest.java
private static final boolean IS_MODEL_QUANTIZED = false;
这里就涉及到 tensor 接收图片的大小以及量化问题:
查看.tflite 文件的输入输出详细类型信息:
import tensorflow as tf

model_path = '/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/navy_tflite.tflite'
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Fetch the input and output tensor details (name, index, shape, dtype, quantization).
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Print each section's banner followed by its details.
for banner, details in (
        ("==================input===================", input_details),
        ("==================output==================", output_details)):
    print(banner)
    print(details)
==================input===================
[{'name': 'serving_default_input:0', 'index': 0, 'shape': array([ 1, 300, 300, 3], dtype=int32), 'shape_signature': array([ 1, 300, 300, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
==================output==================
[{'name': 'StatefulPartitionedCall:3', 'index': 247, 'shape': array([ 1, 10, 4], dtype=int32), 'shape_signature': array([ 1, 10, 4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:2', 'index': 248, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:1', 'index': 249, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:0', 'index': 250, 'shape': array([1], dtype=int32), 'shape_signature': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
详细分析.tflite 文件run 时需要输入输出的数据类型:
input:只有一个节点,表示模型只有一个输入。
节点名称:serving_default_input:0
输入形状:'shape': array([ 1, 300, 300, 3]) 表示 batch 大小为 1,图像大小为 300x300,3 通道。
因此在对图片进行识别的时候,要将图片放缩到300x300 ,再交给.tflite 文件进行处理。
数据类型:'dtype': <class 'numpy.float32'> 表示输入是 float32,即浮点(未量化)模型;'quantization': (0.0, 0) 也说明输入没有量化。
数据类型关系到图像资料的处理:不同的量化需要不同的处理方式。
// Copy the bitmap's pixels into intValues, one packed int per pixel.
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
// Reset imgData's position so the writes below start at the beginning of the buffer.
imgData.rewind();
// NOTE(review): intValues was filled with stride bitmap.getWidth(), but is indexed
// with inputSize here, so this assumes the bitmap is already inputSize x inputSize
// (300x300 for this model) -- confirm the caller resizes it first.
for (int i = 0; i < inputSize; ++i) {
for (int j = 0; j < inputSize; ++j) {
int pixelValue = intValues[i * inputSize + j];
// Bits 16-23, 8-15 and 0-7 are written as R, G, B (assumes packed ARGB pixels,
// which is what Android's Bitmap.getPixels produces).
if (isModelQuantized) {
// Quantized model
// One raw byte (0-255) per channel: 3 bytes per pixel, 300*300*3 = 270000 bytes.
imgData.put((byte) ((pixelValue >> 16) & 0xFF));
imgData.put((byte) ((pixelValue >> 8) & 0xFF));
imgData.put((byte) (pixelValue & 0xFF));
} else { // Float model
// One 4-byte float per channel, normalized as (value - IMAGE_MEAN) / IMAGE_STD:
// 12 bytes per pixel, 300*300*3*4 = 1080000 bytes.
imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
}
}
}
output:四个节点,表示四个输出(下面 array([...], dtype=int32) 是各输出张量的形状,输出数据本身的类型都是 float32):
节点0:StatefulPartitionedCall:3,形状 [1, 10, 4]:对象检测的矩形框坐标
节点1:StatefulPartitionedCall:2,形状 [1, 10]:对象检测的类别索引
节点2:StatefulPartitionedCall:1,形状 [1, 10]:对象检测的概率(置信度)
节点3:StatefulPartitionedCall:0,形状 [1]:检测到的物体总数
总结:使用 .tflite 文件之前,要先查看模型输入的数据类型和量化参数,并在 demo 中做相应的处理(例如设置 TF_OD_API_IS_QUANTIZED / IS_MODEL_QUANTIZED)。