我用tf2目标检测api训练的savedmodel转可以在安卓上应用的tflite模型。
api导出savedmodel模型
我查阅在android下得tflite task库使用tflite模型进行nnapi加速或者应用tflite模型必须得进行matedata元数据得写入,在转化得savedmodel模型必须得符合规定得4个输出。具体的不列出来了,在api下的object_detection目录下的export_tflite_graph_tf2.py可以进行这种形式的导出。
savedmodel转tflite
转化代码要在tf2版本下进行。
import os
import tensorflow as tf
try:
def representative_dataset():
data_folder = "C:/Users/PaulY/Desktop/mp4img"
data_files = os.listdir(data_folder)
for file in data_files:
image_path = os.path.join(data_folder, file)
image = tf.io.read_file(image_path)
image = tf.image.decode_jpeg(image, channels=3)
image = tf.image.resize(image, [320, 320])
image = tf.cast(image, tf.float32) / 255.0
image = tf.expand_dims(image, axis=0)
yield [image]
converter = tf.lite.TFLiteConverter.from_saved_model("C:/Users/PaulY/Desktop/tf2/litemodelpb/saved_model")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
# Ensure that if any ops can't be quantized, the converter throws an error
# tf.lite.OpsSet.TFLITE_BUILTINS_INT8,, tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
# converter.inference_output_type = tf.uint8
# converter.allow_custom_ops = True
# If you want to quantize the model, set it True; else, set it False
converter.post_training_quantize = True
tflite_model_quant = converter.convert()
with open('nomatedatamobilenetv1.tflite','wb') as g:
g.write(tflite_model_quant)
print('tflite export success')
except Exception as e:
print('tflite export failure: %s' % e)
上述代码中,首先加载你的 TensorFlow 模型,然后创建一个 TFLite 转换器。接下来,我们针对量化和 uint8 输入进行了一些设置:
converter.optimizations:用于设置转换过程中的优化选项,这里使用的是默认优化设置。
converter.target_spec.supported_ops:设置支持的操作集为 TFLITE_BUILTINS_INT8,表示将使用量化操作进 行模型压缩。
这里还可以设置tf.lite.OpsSet.TFLITE_BUILTINS 和 tf.lite.OpsSet.SELECT_TF_OPS 。这个参数是用于设置 TFLite 转换过程中所使用的操作集的参数。
tf.lite.OpsSet.TFLITE_BUILTINS 表示使用 TFLite 内置操作集。TFLite 内置操作集是在 TFLite 框架中实现的一组基本操作,适用于大多数应用场景。它包含了常见的运算操作,如卷积、池化、全连接等。
tf.lite.OpsSet.SELECT_TF_OPS 表示选择 TensorFlow 的操作集。该操作集允许使用 TensorFlow 中的所有操作,在转换过程中能够更好地保留原始模型的完整功能。但请注意,这可能会导致生成的 TFLite 模型较大,并且可能不适用于某些资源受限的环境。
通常情况下,如果你的模型仅使用了基本操作或者你希望最小化模型大小和资源消耗,建议使用 tf.lite.OpsSet.TFLITE_BUILTINS。如果你的模型使用了一些特定的 TensorFlow 操作,并且你需要保留这些操作的功能性,可以选择 tf.lite.OpsSet.SELECT_TF_OPS。这里我使用ssdmobilenet模型官方支持模型可以选次参数。
converter.inference_input_type:设置输入数据类型为 uint8。
converter.inference_output_type:设置输出数据类型为 float32。可以不进行设置。
提供一个代表性数据集函数 representative_dataset() 用于校准和量化模型。你可以在该函数中生成或加载适合你的数据集,并返回一个形状符合模型输入要求的数据。最后,将模型转换为 TFLite 格式并保存到文件中。这步为校准。如果不进行这个步骤或使用随机生成的数据会导致生成的tflite模型无法工作。这也是影响精度的重要的一步。
启用 post_training_quantize 可以触发训练后量化过程,将浮点数模型转换为定点数模型,进一步优化模型的大小和性能。通过量化,可以减小模型的存储需求并提高推理速度。
请确保在执行此操作之前,已经完成了模型的训练,并保存了 TensorFlow SavedModel 格式的模型。
为转换的模型添加元数据
这步我们需要下载tflite_support库,pip就可以。用官方的转化必须是你使用的模型是官方的模型。
from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils
ObjectDetectorWriter = object_detector.MetadataWriter
_MODEL_PATH = "nomatedatamobilenetv1.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "labels.txt"
_SAVE_TO_PATH = "mobilenetv1.tflite"
# Normalization parameters is required when reprocessing the image. It is
# optional if the image pixel values are in range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction for normalization and
# quantization parameters below for more details.
# https://www.tensorflow.org/lite/models/convert/metadata#normalization_and_quantization_parameters)
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5
# Create the metadata writer.
writer = ObjectDetectorWriter.create_for_inference(
writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
[_LABEL_FILE])
# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())
# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)