Adding metadata to a TFLite model

1. Introduction to metadata

2. Metadata tools

3. Procedure

1. Introduction to metadata

TensorFlow Lite metadata provides a standard for describing models. The metadata is an important source of knowledge about what a model does and about its inputs and outputs. It consists of a human-readable part and a machine-readable part.

Note: metadata has already been added to all image models published in TensorFlow Lite hosted models and on TensorFlow Hub.

The metadata consists of three main parts:

(1) Model information: an overall description of the model, along with items such as license terms.
(2) Input information: a description of the inputs and of any required preprocessing, such as normalization.
(3) Output information: a description of the outputs and of any required post-processing, such as mapping to labels.


For inputs and outputs, TensorFlow Lite metadata was designed not around specific model types but around input and output types. Regardless of what the model does functionally, as long as its inputs and outputs are one of the following types, or a combination of them, TensorFlow Lite metadata supports it (see the sketch after this list):

(1) Feature: unsigned integers or float32 numbers.
(2) Image: the metadata currently supports RGB and grayscale images.
(3) Bounding box: rectangular bounding boxes. The schema supports several numbering schemes.
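
These three content types map directly onto the FlatBuffers object API used later in this post. Below is a minimal sketch, assuming the generated names from the metadata schema (ContentT, FeaturePropertiesT, ImagePropertiesT, BoundingBoxPropertiesT); each content object would then be assigned to the content field of a TensorMetadataT:

from tflite_support import metadata_schema_py_generated as _metadata_fb

# Feature tensor (e.g. a float32 score vector).
feature_content = _metadata_fb.ContentT()
feature_content.contentProperties = _metadata_fb.FeaturePropertiesT()
feature_content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)

# Image tensor (RGB or grayscale).
image_content = _metadata_fb.ContentT()
image_content.contentProperties = _metadata_fb.ImagePropertiesT()
image_content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
image_content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)

# Bounding-box tensor; the index/coordinate convention is encoded in the
# fields of BoundingBoxPropertiesT.
box_content = _metadata_fb.ContentT()
box_content.contentProperties = _metadata_fb.BoundingBoxPropertiesT()
box_content.contentPropertiesType = (
    _metadata_fb.ContentProperties.BoundingBoxProperties)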

2. Metadata tools

tflite-support
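
The tflite-support pip package bundles everything used in this post: tflite_support.metadata for reading and populating metadata, tflite_support.metadata_schema_py_generated for the FlatBuffers object API of the metadata schema, and tflite_support.metadata_writers for the officially supported tasks. As a quick check, metadata already packed into a model can be read back as follows (a minimal sketch; "some_model.tflite" is a placeholder path):

from tflite_support import metadata as _metadata

# Display the metadata and associated files packed into an existing model.
displayer = _metadata.MetadataDisplayer.with_model_file("some_model.tflite")
print(displayer.get_metadata_json())                # human-readable JSON view
print(displayer.get_packed_associated_file_list())  # e.g. packed label files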

3. Procedure

(1) First, install the metadata tooling:

pip install tflite-support

(2) Use tflite-support to add the metadata. TFLite already ships metadata writers for several common tasks, which can be called directly from the package, as sketched after the link below.

The corresponding official tutorial source is available at:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/convert/metadata_writer_tutorial.ipynb
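
For instance, for an image classification model the bundled writer can be used directly. The sketch below follows the pattern in the tutorial linked above; the file paths and normalization values are placeholders that must match your own model:

from tflite_support.metadata_writers import image_classifier
from tflite_support.metadata_writers import writer_utils

ImageClassifierWriter = image_classifier.MetadataWriter

# Normalization parameters the model was trained with (placeholder values).
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

writer = ImageClassifierWriter.create_for_inference(
    writer_utils.load_file("classifier.tflite"),  # model buffer
    [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],        # input normalization
    ["labels.txt"])                               # associated label file(s)
writer_utils.save_file(writer.populate(), "classifier_with_metadata.tflite")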

(3) For models that the official writers do not directly support, you need to use the FlatBuffers Python API.

The following adds metadata to a TFLite model for a pose estimation task:

import os
from absl import app
from absl import flags
import tensorflow as tf
import flatbuffers
from tflite_support import metadata_schema_py_generated as _metadata_fb
from tflite_support import metadata as _metadata


FLAGS = flags.FLAGS


def define_flags():
  flags.DEFINE_string("model_file", None,
                      "Path and file name to the TFLite model file.")
  flags.DEFINE_string("label_file", None, "Path to the label file.")
  flags.DEFINE_string("export_directory", None,
                      "Path to save the TFLite model files with metadata.")
  flags.mark_flag_as_required("model_file")
  flags.mark_flag_as_required("label_file")
  flags.mark_flag_as_required("export_directory")


class ModelSpecificInfo(object):
  """Holds information that is specificly tied to an pose estimation."""

  def __init__(self, name, version, image_width, image_height, image_min,
               image_max, num_classes, author):
    self.name = name
    self.version = version
    self.image_width = image_width
    self.image_height = image_height
    self.image_min = image_min
    self.image_max = image_max
    # self.mean = mean
    # self.std = std
    self.num_classes = num_classes
    self.author = author


_MODEL_INFO = {
    "pose_movenet_thunder.tflite":
        ModelSpecificInfo(
            name="pose estimation",
            version="v1",
            image_width=192,
            image_height=192,
            image_min=0,
            image_max=255,
            num_classes=6,
            author="TensorFlow Lite")
}


class MetadataPopulatorForPoseEstimation(object):
  """Populates the metadata for an pose estimation."""

  def __init__(self, model_file, model_info, label_file_path):
    self.model_file = model_file
    self.model_info = model_info
    self.label_file_path = label_file_path
    self.metadata_buf = None

  def populate(self):
    """Creates metadata and then populates it for an pose rstimation."""
    self._create_metadata()
    self._populate_metadata()

  def _create_metadata(self):
    """Creates the metadata for an pose estimation."""

    # Creates model info.
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.name = self.model_info.name
    model_meta.description = (
        "Identify the 17 body joints of the person in the picture, draw them, "
        "and estimate the action based on the joint coordinates.")
    model_meta.version = self.model_info.version
    model_meta.author = self.model_info.author
    model_meta.license = ("Apache License. Version 2.0 "
                          "http://www.apache.org/licenses/LICENSE-2.0.")

    # Creates input info.
    input_meta = _metadata_fb.TensorMetadataT()
    input_meta.name = "image"
    input_meta.description = (
        "Input image to be identify. The expected image is {0} x {1}, with "
        "three channels (red, blue, and green) per pixel. Each value in the "
        "tensor is a single byte between {2} and {3}.".format(
            self.model_info.image_width, self.model_info.image_height,
            self.model_info.image_min, self.model_info.image_max))
    input_meta.content = _metadata_fb.ContentT()
    input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
    input_meta.content.contentProperties.colorSpace = (
        _metadata_fb.ColorSpaceType.RGB)
    input_meta.content.contentPropertiesType = (
        _metadata_fb.ContentProperties.ImageProperties)
    input_stats = _metadata_fb.StatsT()
    input_stats.max = [self.model_info.image_max]
    input_stats.min = [self.model_info.image_min]
    input_meta.stats = input_stats

    # Creates output info.
    output_meta = _metadata_fb.TensorMetadataT()
    output_meta.name = "probability"
    output_meta.description = (
        "Probabilities of the %d labels respectively." %
        self.model_info.num_classes)
    output_meta.content = _metadata_fb.ContentT()
    output_meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()
    output_meta.content.contentPropertiesType = (
        _metadata_fb.ContentProperties.FeatureProperties)
    output_stats = _metadata_fb.StatsT()
    output_stats.max = [1.0]
    output_stats.min = [0.0]
    output_meta.stats = output_stats
    label_file = _metadata_fb.AssociatedFileT()
    label_file.name = os.path.basename(self.label_file_path)
    label_file.description = "Labels for actions that the model can detect."
    label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
    output_meta.associatedFiles = [label_file]

    # Creates subgraph info.
    subgraph = _metadata_fb.SubGraphMetadataT()
    subgraph.inputTensorMetadata = [input_meta]
    subgraph.outputTensorMetadata = [output_meta]
    model_meta.subgraphMetadata = [subgraph]

    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    self.metadata_buf = b.Output()

  def _populate_metadata(self):
    """Populates metadata and label file to the model file."""
    populator = _metadata.MetadataPopulator.with_model_file(self.model_file)
    populator.load_metadata_buffer(self.metadata_buf)
    populator.load_associated_files([self.label_file_path])
    populator.populate()


def main(_):
  model_file = FLAGS.model_file
  model_basename = os.path.basename(model_file)
  if model_basename not in _MODEL_INFO:
    raise ValueError(
        "The model info for, {0}, is not defined yet.".format(model_basename))
  export_model_path = os.path.join(FLAGS.export_directory, model_basename)
  # Copies model_file to export_path.
  tf.io.gfile.copy(model_file, export_model_path, overwrite=False)
  # Generate the metadata objects and put them in the model file
  populator = MetadataPopulatorForPoseEstimation(
      export_model_path, _MODEL_INFO.get(model_basename), FLAGS.label_file)
  populator.populate()

  # Validate the output model file by reading the metadata back and write a
  # JSON file with the metadata under the export path.
  displayer = _metadata.MetadataDisplayer.with_model_file(export_model_path)
  export_json_file = os.path.join(
      FLAGS.export_directory, os.path.splitext(model_basename)[0] + ".json")
  json_file = displayer.get_metadata_json()
  with open(export_json_file, "w") as f:
    f.write(json_file)

  print("Finished populating metadata and associated file to the model:")
  print(model_file)
  print("The metadata json file has been saved to:")
  print(export_json_file)
  print("The associated file that has been been packed to the model is:")
  print(displayer.get_packed_associated_file_list())


if __name__ == "__main__":
  define_flags()
  app.run(main)
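
With a model whose file name matches a key in _MODEL_INFO (here pose_movenet_thunder.tflite), the script can be run roughly as follows; the script name and paths are placeholders:

python add_metadata_to_pose_model.py \
  --model_file=./pose_movenet_thunder.tflite \
  --label_file=./labels.txt \
  --export_directory=./model_with_metadata

Note that tf.io.gfile.copy is called with overwrite=False, so the export directory must not already contain a model with the same name. The JSON file written next to the exported model can be opened to verify that the metadata was packed as expected.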

