ubuntu20.10 tensorflow2.5: Porting a Trained Model to the Android Platform (3): Running Your Own Trained Model


Environment:

ubuntu20.10 + python3 + tensorflow2.x

Reference: ubuntu20.10 安装低版本tensorflow1.8 或者 tensorflow2.5 步骤_Navy的博客-CSDN博客, which details the tensorflow2.5 environment setup and model training.

Reference: ubuntu20.10 tensorflow2.5 将训练后的模型移植到android 平台之官网demo 运行(二)_Navy的博客-CSDN博客, which covers downloading and debugging the complete official demo.

1. Export the trained checkpoints under /tensorflow2.0/models/research/object_detection/training_tf2 as a SavedModel (exports a TF2 detection SavedModel for conversion to TensorFlow Lite).

#tensorflow2.x
python3 export_tflite_graph_tf2.py --pipeline_config_path=training_tf2/ssd_mobilenet_v2_320x320_coco17_tpu-8.config --trained_checkpoint_dir=training_tf2/ --output_directory=training_tf2/train_export/TFlite

The model checkpoint data and the exported SavedModel land under training_tf2/train_export/TFlite/saved_model.
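
As an optional sanity check (a sketch, not part of the original steps), the exported SavedModel can be loaded back and its serving signature printed before converting:

import tensorflow as tf

# Load the exported SavedModel and print its serving signature to
# confirm the export step produced a usable model.
saved_model_dir = '/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/saved_model/'
model = tf.saved_model.load(saved_model_dir)
infer = model.signatures['serving_default']
print(infer.structured_input_signature)
print(infer.structured_outputs)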

2. Convert the SavedModel into navy_tflite.tflite with pb_to_tf2lite.py.

pb_to_tf2lite.py:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import tensorflow as tf

# Directory of the SavedModel exported in step 1.
saved_model_dir = '/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/saved_model/'

# Convert the SavedModel to a TensorFlow Lite flatbuffer.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.experimental_new_converter = True
tflite_model = converter.convert()

# Write the converted model to disk.
with open('/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/navy_tflite.tflite', 'wb') as f:
    f.write(tflite_model)
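
If a smaller model is wanted, dynamic-range quantization can optionally be turned on at this step (a sketch using the standard converter option; note that a quantized model also changes how the demo must feed pixels, see step 4 below):

import tensorflow as tf

saved_model_dir = '/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/saved_model/'
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# Dynamic-range quantization stores weights as int8 (roughly 4x smaller);
# the input tensor itself stays float32.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()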

3. Combine navy_tflite.tflite and navy_lable.txt into navy_tflite_metadata.tflite with metadata_writer_for_object_detection.py.

navy_lable.txt: list every object class your model was trained on, one label per line:

cup
curtain

metadata_writer_for_object_detection.py:

"""Writes metadata and label file to the object_detection models."""

from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils

ObjectDetectorWriter = object_detector.MetadataWriter
_MODEL_PATH = "/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/navy_tflite.tflite"

# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/navy_lable.txt"
_SAVE_TO_PATH = "/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/navy_tflite_metadata.tflite"

# Normalization parameters are required when preprocessing the image. They
# are optional if the image pixel values are in the range of [0, 255] and the
# input tensor is quantized to uint8. See the introduction to normalization
# and quantization parameters below for more details.
# https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters

_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Create the metadata writer.
writer = ObjectDetectorWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)
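
To confirm the label file was actually packed, the populated model can be inspected with the same tflite_support metadata API (a quick check, not part of the original script):

from tflite_support import metadata

# Display the metadata and the associated files packed into the model.
displayer = metadata.MetadataDisplayer.with_model_file(
    "/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/navy_tflite_metadata.tflite")
print(displayer.get_metadata_json())
print(displayer.get_packed_associated_file_list())  # should list navy_lable.txt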

Now, why generate navy_tflite_metadata.tflite at all?

Metadata here means data that describes the model itself.

navy_tflite_metadata.tflite = navy_tflite.tflite + navy_lable.txt (writes the metadata and label file into the object detection model). Put simply, the label file is stored in the metadata field of the TFLite model schema.

For the official demo, a .tflite file without labels attached crashes at runtime with: java.lang.IllegalStateException: This model does not contain associated files, and is not a Zip file.

For details, see: https://tensorflow.google.cn/lite/convert/metadata_writer_tutorial?hl=en
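
That exception message also hints at the mechanism: a metadata-populated model is a TFLite flatbuffer that doubles as a valid ZIP archive holding the associated files, so the packing can be verified with nothing more than Python's zipfile module (a sketch):

import zipfile

# A metadata-populated .tflite is also readable as a zip archive;
# the packed label file should appear in the listing.
with zipfile.ZipFile('navy_tflite_metadata.tflite') as z:
    print(z.namelist())  # expect: ['navy_lable.txt']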

Addendum: for image classification (image_classification), use metadata_writer_for_image_classifier.py.

Path: examples-master\lite\examples\image_classification\metadata

"""Writes metadata and label file to the image classifier models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from absl import app
from absl import flags
import tensorflow as tf

import flatbuffers
# pylint: disable=g-direct-tensorflow-import
from tflite_support import metadata_schema_py_generated as _metadata_fb
from tflite_support import metadata as _metadata
# pylint: enable=g-direct-tensorflow-import

FLAGS = flags.FLAGS


def define_flags():
  flags.DEFINE_string("model_file", None,
                      "Path and file name to the TFLite model file.")
  flags.DEFINE_string("label_file", None, "Path to the label file.")
  flags.DEFINE_string("export_directory", None,
                      "Path to save the TFLite model files with metadata.")
  flags.mark_flag_as_required("model_file")
  flags.mark_flag_as_required("label_file")
  flags.mark_flag_as_required("export_directory")


class ModelSpecificInfo(object):
  """Holds information that is specificly tied to an image classifier."""

  def __init__(self, name, version, image_width, image_height, image_min,
               image_max, mean, std, num_classes, author):
    self.name = name
    self.version = version
    self.image_width = image_width
    self.image_height = image_height
    self.image_min = image_min
    self.image_max = image_max
    self.mean = mean
    self.std = std
    self.num_classes = num_classes
    self.author = author


_MODEL_INFO = {
    "mobilenet_v1_0.75_160_quantized.tflite":
        ModelSpecificInfo(
            name="MobileNetV1 image classifier",
            version="v1",
            image_width=160,
            image_height=160,
            image_min=0,
            image_max=255,
            mean=[127.5],
            std=[127.5],
            num_classes=1001,
            author="TensorFlow")
}


class MetadataPopulatorForImageClassifier(object):
  """Populates the metadata for an image classifier."""

  def __init__(self, model_file, model_info, label_file_path):
    self.model_file = model_file
    self.model_info = model_info
    self.label_file_path = label_file_path
    self.metadata_buf = None

  def populate(self):
    """Creates metadata and then populates it for an image classifier."""
    self._create_metadata()
    self._populate_metadata()

  def _create_metadata(self):
    """Creates the metadata for an image classifier."""

    # Creates model info.
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.name = self.model_info.name
    model_meta.description = ("Identify the most prominent object in the "
                              "image from a set of %d categories." %
                              self.model_info.num_classes)
    model_meta.version = self.model_info.version
    model_meta.author = self.model_info.author
    model_meta.license = ("Apache License. Version 2.0 "
                          "http://www.apache.org/licenses/LICENSE-2.0.")

    # Creates input info.
    input_meta = _metadata_fb.TensorMetadataT()
    input_meta.name = "image"
    input_meta.description = (
        "Input image to be classified. The expected image is {0} x {1}, with "
        "three channels (red, blue, and green) per pixel. Each value in the "
        "tensor is a single byte between {2} and {3}.".format(
            self.model_info.image_width, self.model_info.image_height,
            self.model_info.image_min, self.model_info.image_max))
    input_meta.content = _metadata_fb.ContentT()
    input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
    input_meta.content.contentProperties.colorSpace = (
        _metadata_fb.ColorSpaceType.RGB)
    input_meta.content.contentPropertiesType = (
        _metadata_fb.ContentProperties.ImageProperties)
    input_normalization = _metadata_fb.ProcessUnitT()
    input_normalization.optionsType = (
        _metadata_fb.ProcessUnitOptions.NormalizationOptions)
    input_normalization.options = _metadata_fb.NormalizationOptionsT()
    input_normalization.options.mean = self.model_info.mean
    input_normalization.options.std = self.model_info.std
    input_meta.processUnits = [input_normalization]
    input_stats = _metadata_fb.StatsT()
    input_stats.max = [self.model_info.image_max]
    input_stats.min = [self.model_info.image_min]
    input_meta.stats = input_stats

    # Creates output info.
    output_meta = _metadata_fb.TensorMetadataT()
    output_meta.name = "probability"
    output_meta.description = "Probabilities of the %d labels respectively." % self.model_info.num_classes
    output_meta.content = _metadata_fb.ContentT()
    output_meta.content.content_properties = _metadata_fb.FeaturePropertiesT()
    output_meta.content.contentPropertiesType = (
        _metadata_fb.ContentProperties.FeatureProperties)
    output_stats = _metadata_fb.StatsT()
    output_stats.max = [1.0]
    output_stats.min = [0.0]
    output_meta.stats = output_stats
    label_file = _metadata_fb.AssociatedFileT()
    label_file.name = os.path.basename(self.label_file_path)
    label_file.description = "Labels for objects that the model can recognize."
    label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
    output_meta.associatedFiles = [label_file]

    # Creates subgraph info.
    subgraph = _metadata_fb.SubGraphMetadataT()
    subgraph.inputTensorMetadata = [input_meta]
    subgraph.outputTensorMetadata = [output_meta]
    model_meta.subgraphMetadata = [subgraph]

    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    self.metadata_buf = b.Output()

  def _populate_metadata(self):
    """Populates metadata and label file to the model file."""
    populator = _metadata.MetadataPopulator.with_model_file(self.model_file)
    populator.load_metadata_buffer(self.metadata_buf)
    populator.load_associated_files([self.label_file_path])
    populator.populate()


def main(_):
  model_file = FLAGS.model_file
  model_basename = os.path.basename(model_file)
  if model_basename not in _MODEL_INFO:
    raise ValueError(
        "The model info for, {0}, is not defined yet.".format(model_basename))

  export_model_path = os.path.join(FLAGS.export_directory, model_basename)

  # Copies model_file to export_path.
  tf.io.gfile.copy(model_file, export_model_path, overwrite=False)

  # Generate the metadata objects and put them in the model file
  populator = MetadataPopulatorForImageClassifier(
      export_model_path, _MODEL_INFO.get(model_basename), FLAGS.label_file)
  populator.populate()

  # Validate the output model file by reading the metadata and produce
  # a json file with the metadata under the export path
  displayer = _metadata.MetadataDisplayer.with_model_file(export_model_path)
  export_json_file = os.path.join(FLAGS.export_directory,
                                  os.path.splitext(model_basename)[0] + ".json")
  json_file = displayer.get_metadata_json()
  with open(export_json_file, "w") as f:
    f.write(json_file)

  print("Finished populating metadata and associated file to the model:")
  print(model_file)
  print("The metadata json file has been saved to:")
  print(export_json_file)
  print("The associated file that has been been packed to the model is:")
  print(displayer.get_packed_associated_file_list())


if __name__ == "__main__":
  define_flags()
  app.run(main)
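
The script would be invoked along these lines (the paths below are placeholders for wherever your model, label file, and export directory live):

python3 metadata_writer_for_image_classifier.py \
    --model_file=./mobilenet_v1_0.75_160_quantized.tflite \
    --label_file=./labels.txt \
    --export_directory=./model_with_metadata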

4. Copy navy_tflite_metadata.tflite and navy_lable.txt into the demo's src\main\assets, then change the two places in DetectorActivity.java and DetectorTest.java that load lite-model_ssd_mobilenet_v1_1_metadata_2.tflite and labelmap.txt so they load navy_tflite_metadata.tflite and navy_lable.txt instead.

Install and run the app: it crashes immediately. logcat shows the error: java.lang.IllegalArgumentException: Cannot copy to a TensorFlowLite tensor (serving_default_input:0) with 1080000 bytes from a Java Buffer with 270000 bytes.

Fix: in DetectorActivity.java and DetectorTest.java, set the two quantization flags to false:

// DetectorActivity.java
private static final boolean TF_OD_API_IS_QUANTIZED = false;
// DetectorTest.java
private static final boolean IS_MODEL_QUANTIZED = false;

This comes down to the input tensor's expected image size and quantization. The byte counts in the exception line up exactly: the float model wants 300 × 300 × 3 values at 4 bytes each (1,080,000 bytes), while the quantized code path was filling the buffer with 1 byte per value (300 × 300 × 3 = 270,000 bytes).

To inspect the input and output tensor details of a .tflite file:

import tensorflow as tf

model_path = '/tensorflow2.0/models/research/object_detection/training_tf2/train_export/TFlite/navy_tflite.tflite'

interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Get the input and output tensor details.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("==================input===================")
print(input_details)
print("==================output==================")
print(output_details)
Output:

==================input===================
[{'name': 'serving_default_input:0', 'index': 0, 'shape': array([  1, 300, 300,   3], dtype=int32), 'shape_signature': array([  1, 300, 300,   3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
==================output==================
[{'name': 'StatefulPartitionedCall:3', 'index': 247, 'shape': array([ 1, 10,  4], dtype=int32), 'shape_signature': array([ 1, 10,  4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:2', 'index': 248, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:1', 'index': 249, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:0', 'index': 250, 'shape': array([1], dtype=int32), 'shape_signature': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]

A closer look at the data types the .tflite file expects at run time:

input: a single node, meaning a single input.

Node name: serving_default_input:0

Input shape: 'shape': array([1, 300, 300, 3]) means a 300 × 300 image with 3 channels.

So before running detection, the image must be resized to 300 × 300 and only then handed to the .tflite model.

Data type: 'dtype': <class 'numpy.float32'> means the input values are floating point, i.e. this is a float (non-quantized) model.

The data type determines how the image data must be prepared; different quantization schemes need different handling, as the demo's preprocessing code shows:

// From the official demo: pixels are unpacked from the bitmap and written
// into imgData either as raw bytes (quantized model) or as normalized
// floats (float model).
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

imgData.rewind();
for (int i = 0; i < inputSize; ++i) {
  for (int j = 0; j < inputSize; ++j) {
    int pixelValue = intValues[i * inputSize + j];
    if (isModelQuantized) {
      // Quantized model: one uint8 per channel.
      imgData.put((byte) ((pixelValue >> 16) & 0xFF));
      imgData.put((byte) ((pixelValue >> 8) & 0xFF));
      imgData.put((byte) (pixelValue & 0xFF));
    } else {
      // Float model: normalize each channel to [-1, 1] using IMAGE_MEAN/IMAGE_STD (127.5).
      imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
      imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
      imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
    }
  }
}

output: four nodes, meaning four outputs.

Node 0: StatefulPartitionedCall:3, shape array([1, 10, 4]): bounding-box coordinates of the detections

Node 1: StatefulPartitionedCall:2, shape array([1, 10]): class indices of the detections

Node 2: StatefulPartitionedCall:1, shape array([1, 10]): detection scores

Node 3: StatefulPartitionedCall:0, shape array([1]): number of detected objects
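
Putting the input and output contracts together, here is a minimal Python sketch of one inference pass against navy_tflite.tflite (test.jpg is a placeholder image path; the output ordering matches the dump above):

import tensorflow as tf

# Minimal single-image inference against the float model.
interpreter = tf.lite.Interpreter(model_path='navy_tflite.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Decode, resize to 300x300, and normalize to [-1, 1], matching the
# demo's float path ((pixel - 127.5) / 127.5).
img = tf.io.decode_jpeg(tf.io.read_file('test.jpg'), channels=3)
img = tf.image.resize(img, (300, 300))       # float32, shape [300, 300, 3]
img = (img - 127.5) / 127.5
img = tf.expand_dims(img, 0)                 # shape [1, 300, 300, 3]

interpreter.set_tensor(input_details[0]['index'], img.numpy())
interpreter.invoke()

boxes = interpreter.get_tensor(output_details[0]['index'])    # [1, 10, 4] box coordinates
classes = interpreter.get_tensor(output_details[1]['index'])  # [1, 10] class indices
scores = interpreter.get_tensor(output_details[2]['index'])   # [1, 10] confidence scores
count = interpreter.get_tensor(output_details[3]['index'])    # [1] number of detections
print(int(count[0]), 'detections; top score:', float(scores[0][0]))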

Summary: when using a .tflite file, check its quantization type first, then handle the input accordingly in the demo.
