TensorFlow Lite Model Maker: Image Classification, with Source Code

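This post shows how to use the TensorFlow Lite Model Maker library to turn a pretrained TensorFlow model into a TFLite model suitable for mobile devices. The library supports image classification, object detection, text classification, and other tasks. The example walks through data loading, model training, evaluation, and quantization, which reduces model size and improves inference speed. The quantized model keeps high accuracy while being markedly smaller and faster at inference.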

TFLite_tutorials

The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow neural-network model to particular input data when deploying this model for on-device ML applications.
Note: the goal here is a model in .tflite format that can be deployed on mobile or embedded devices.

The table below lists the task types that TFLite Model Maker currently supports.

Supported Tasks | Task Utility
Image Classification (tutorial, api) | Classify images into predefined categories.
Object Detection (tutorial, api) | Detect objects in real time.
Text Classification (tutorial, api) | Classify text into predefined categories.
BERT Question Answer (tutorial, api) | Find the answer in a certain context for a given question with BERT.
Audio Classification (tutorial, api) | Classify audio into predefined categories.
Recommendation (demo, api) | Recommend items based on the context information for on-device scenario.

If your tasks are not supported, please first use TensorFlow to retrain a TensorFlow model with transfer learning (following guides like images, text, audio) or train it from scratch, and then convert it to TensorFlow Lite model.
Note: if the model you want to train does not fit any of the task types above, train a TensorFlow model first and then convert it to TFLite.

To use TensorFlow Lite Model Maker, install it first:

pip install tflite-model-maker

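At its core, this example is a classification task.
Swap in different models and compare the final accuracy, along with the TFLite file size, inference speed, memory usage, and CPU usage.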

The following snippet downloads the dataset:

image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

The dataset is laid out as follows:
flower_photos
|__ daisy
|______ 100080576_f52e8ee070_n.jpg
|______ 14167534527_781ceb1b7a_n.jpg
|______ …
|__ dandelion
|______ 10043234166_e6dd915111_n.jpg
|______ 1426682852_e62169221f_m.jpg
|______ …
|__ roses
|______ 102501987_3cdb8e5394_n.jpg
|______ 14982802401_a3dfb22afb.jpg
|______ …
|__ sunflowers
|______ 12471791574_bb1be83df4.jpg
|______ 15122112402_cafa41934f.jpg
|______ …
|__ tulips
|______ 13976522214_ccec508fe7.jpg
|______ 14487943607_651e8062a1_m.jpg
|______ …

Load the dataset and split it into training, validation, and test sets:

data = DataLoader.from_folder(image_path)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
assert tf.__version__.startswith('2')

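The assert checks that the installed TensorFlow version string starts with '2', i.e. that TensorFlow 2.x is in use.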
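The two `split` calls above produce an 80/10/10 partition. A minimal sketch of the arithmetic follows, assuming each split puts `int(size * fraction)` elements in the first subset and the remainder in the second, and assuming flower_photos contains 3670 images; `split_sizes` is a hypothetical helper, not part of the library.

```python
def split_sizes(total, fraction):
    """Return (first, second) subset sizes for one split call.

    Assumption: the first subset gets int(total * fraction) elements
    and the second subset gets whatever is left.
    """
    first = int(total * fraction)
    return first, total - first


total_images = 3670  # assumed size of the flower_photos dataset

train_size, rest_size = split_sizes(total_images, 0.8)
validation_size, test_size = split_sizes(rest_size, 0.5)

print(train_size, validation_size, test_size)  # 2936 367 367
```

Under these assumptions the test set has 367 images.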

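Training results: train_acc = 0.9698, val_acc = 0.9375, test_acc = 0.9210. Overall this matches the usual generalization pattern, with accuracy dropping slightly from training to validation to test data.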

import os
import time

import numpy as np
import tensorflow as tf
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt

assert tf.__version__.startswith('2')

image_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')

data = DataLoader.from_folder(image_path)
# data = data.gen_dataset(batch_size=1)
train_data, rest_data = data.split(0.8)
# for batch in data.take(1):
#     print(batch)
#     break

validation_data, test_data = rest_data.split(0.5)

model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('efficientnet_lite0'), epochs=20)

loss, accuracy = model.evaluate(test_data)

model.export(export_dir='./testTFlite', export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))

start = time.time()
print(model.evaluate_tflite('./testTFlite/model.tflite', test_data))
end = time.time()
print('elapsed time: ', end - start)

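Judging from the output log of the run above, quantization costs very little accuracy, and the quantized model (efficientnet_lite0) is 4.0 MB.
The CPU usage screenshot shows inference running on a single CPU. test_data contains 367 images, and the whole evaluation took 273.43 s.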

config = QuantizationConfig.for_float16()
model.export(export_dir='./testTFlite', tflite_filename='model_fp16.tflite', quantization_config=config, export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))

When the model is exported as fp16 instead, the file is 6.8 MB (efficientnet_lite0) and evaluation takes 5.54 s, which is much faster.
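To get a feel for what float16 storage costs and saves, here is a small standard-library sketch that packs the same made-up weights as 32-bit and 16-bit floats and measures the round-trip error. It does not touch a real TFLite file; the weight values are invented for illustration.

```python
import struct

# 101 made-up weights in the range [-5.0, 5.0].
weights = [0.1 * i for i in range(-50, 51)]
count = len(weights)

# Pack the same values as little-endian float32 ("f") and float16 ("e").
as_fp32 = struct.pack(f"<{count}f", *weights)
as_fp16 = struct.pack(f"<{count}e", *weights)

# Read the float16 bytes back to see how much precision was lost.
restored = struct.unpack(f"<{count}e", as_fp16)
max_error = max(abs(a - b) for a, b in zip(weights, restored))

print(len(as_fp32), len(as_fp16))  # 404 202
print(max_error)
```

Each float16 value takes 2 bytes instead of 4, so weight storage halves at the cost of a small rounding error. An 8-bit integer weight would take 1 byte, which would be consistent with the 4.0 MB default export being smaller than the 6.8 MB fp16 one, though the post does not state which scheme the default export used.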

model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=model_spec.get('mobilenet_v2'), epochs=20)

After switching the model to mobilenet_v2, the exported fp16 model is 4.6 MB and evaluation takes 4.36 s.

inception_v3_spec = image_classifier.ModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]
model = image_classifier.create(train_data, validation_data=validation_data,
                                model_spec=inception_v3_spec, epochs=20)

After switching the model to inception_v3, the exported fp16 model is 43.8 MB and evaluation takes 25.31 s.
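The timings reported so far are totals for one `evaluate_tflite` call. The sketch below converts them to a rough per-image figure, assuming every run evaluated the same 367 test images and that nearly all of the elapsed time is inference; the dictionary keys are descriptive labels chosen here.

```python
NUM_TEST_IMAGES = 367

# Total evaluation time in seconds, as reported in this post.
total_seconds = {
    "efficientnet_lite0 (default export)": 273.43,
    "efficientnet_lite0 (fp16)": 5.54,
    "mobilenet_v2 (fp16)": 4.36,
    "inception_v3 (fp16)": 25.31,
}

# Average time per image, in milliseconds.
per_image_ms = {
    name: 1000.0 * seconds / NUM_TEST_IMAGES
    for name, seconds in total_seconds.items()
}

for name, ms in per_image_ms.items():
    print(f"{name}: {ms:.1f} ms/image")

# How much faster the fp16 export ran than the default export.
speedup = (total_seconds["efficientnet_lite0 (default export)"]
           / total_seconds["efficientnet_lite0 (fp16)"])
print(f"fp16 speedup over default export: {speedup:.1f}x")
```

These are wall-clock averages from single runs, so treat them as rough comparisons rather than benchmarks.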

Common Dataset used for tasks.

class DataLoader(object):
  """This class provides generic utilities for loading customized domain data that will be used later in model retraining.

  For different ML problems or tasks, such as image classification, text
  classification etc., a subclass is provided to handle task-specific data
  loading requirements.
  """

  def __init__(self, dataset, size):
    """Init function for class `DataLoader`.

    In most cases, one should use helper functions like `from_folder` to create
    an instance of this class.

    Args:
      dataset: A tf.data.Dataset object that contains a potentially large set of
        elements, where each element is a pair of (input_data, target). The
        `input_data` means the raw input data, like an image, a text etc., while
        the `target` means some ground truth of the raw input data, such as the
        classification label of the image etc.
      size: The size of the dataset. tf.data.Dataset doesn't support a function
        to get the length directly since it's lazy-loaded and may be infinite.
    """
    self._dataset = dataset
    self._size = size

  def gen_dataset(self,
                  batch_size=1,
                  is_training=False,
                  shuffle=False,
                  input_pipeline_context=None,
                  preprocess=None,
                  drop_remainder=False):
    """Generate a shared and batched tf.data.Dataset for training/evaluation."""

Image dataloader

class ImageClassifierDataLoader(dataloader.ClassificationDataLoader):
  """DataLoader for image classifier."""

  @classmethod
  def from_folder(cls, filename, shuffle=True):
    """Image analysis for image classification load images with labels.

    Assume the image data of the same label are in the same subdirectory.

    Args:
      filename: Name of the file.
      shuffle: boolean, if shuffle, random shuffle data.

    Returns:
      ImageDataset containing images and labels and other related info.
    """
  @classmethod
  def from_tfds(cls, name):
    """Loads data from tensorflow_datasets."""
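The `from_folder` docstring says that images with the same label live in the same subdirectory. A minimal standard-library sketch of that convention follows. It is not the library's implementation: `list_labeled_images` is a hypothetical helper, and assigning label indices in sorted folder-name order is an assumption.

```python
import os
import tempfile


def list_labeled_images(root):
    """Return (label_names, samples) for a folder-per-class layout.

    Each sample is a (file_path, label_index) pair. Label indices follow
    the sorted order of the subdirectory names (an assumption).
    """
    label_names = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name)))
    samples = []
    for index, label in enumerate(label_names):
        folder = os.path.join(root, label)
        for filename in sorted(os.listdir(folder)):
            samples.append((os.path.join(folder, filename), index))
    return label_names, samples


# Build a tiny throwaway tree with empty placeholder files.
with tempfile.TemporaryDirectory() as root:
    for label, filename in [("roses", "a.jpg"),
                            ("daisy", "b.jpg"),
                            ("daisy", "c.jpg")]:
        os.makedirs(os.path.join(root, label), exist_ok=True)
        open(os.path.join(root, label, filename), "wb").close()
    label_names, samples = list_labeled_images(root)

print(label_names)                        # ['daisy', 'roses']
print([index for _, index in samples])    # [0, 0, 1]
```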

ImageNet preprocessing

class Preprocessor(object):
  """Preprocessing for image classification."""

  def __init__(self,
               input_shape,
               num_classes,
               mean_rgb,
               stddev_rgb,
               use_augmentation=False):
    self.input_shape = input_shape
    self.num_classes = num_classes
    self.mean_rgb = mean_rgb
    self.stddev_rgb = stddev_rgb
    self.use_augmentation = use_augmentation

  def __call__(self, image, label, is_training=True):
    if self.use_augmentation:
      return self._preprocess_with_augmentation(image, label, is_training)
    return self._preprocess_without_augmentation(image, label)

  def _preprocess_with_augmentation(self, image, label, is_training):
    """Image preprocessing method with data augmentation."""
    image_size = self.input_shape[0]
    if is_training:
      image = preprocess_for_train(image, image_size)
    else:
      image = preprocess_for_eval(image, image_size)

    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)

    label = tf.one_hot(label, depth=self.num_classes)
    return image, label

  # TODO(yuqili): Changes to preprocess to support batch input.
  def _preprocess_without_augmentation(self, image, label):
    """Image preprocessing method without data augmentation."""
    image = tf.cast(image, tf.float32)

    image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
    image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)

    image = tf.compat.v1.image.resize(image, self.input_shape)
    label = tf.one_hot(label, depth=self.num_classes)
    return image, label
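Both preprocessing paths above end with the same two steps: subtract `mean_rgb` and divide by `stddev_rgb`, then one-hot encode the label. The sketch below reproduces those two steps for a single pixel in plain Python. The mean and standard deviation values are assumed for illustration, and `normalize_pixel` and `one_hot` are hypothetical helpers.

```python
def normalize_pixel(pixel, mean_rgb, stddev_rgb):
    """Apply (value - mean) / stddev to each RGB channel."""
    return [(c - m) / s for c, m, s in zip(pixel, mean_rgb, stddev_rgb)]


def one_hot(label, num_classes):
    """Return a list with 1.0 at position `label` and 0.0 elsewhere."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


# Assumed normalization constants, not taken from any particular model spec.
MEAN_RGB = [127.0, 127.0, 127.0]
STDDEV_RGB = [128.0, 128.0, 128.0]

print(normalize_pixel([255.0, 127.0, 0.0], MEAN_RGB, STDDEV_RGB))
# [1.0, 0.0, -0.9921875]
print(one_hot(2, 5))
# [0.0, 0.0, 1.0, 0.0, 0.0]
```

With these constants, pixel values in [0, 255] map to roughly [-1, 1].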

class ImageClassifier(classification_model.ClassificationModel):
  """ImageClassifier class for inference and exporting to tflite."""

  def __init__(self,
               model_spec,
               index_to_label,
               shuffle=True,
               hparams=hub_lib.get_default_hparams(),
               use_augmentation=False,
               representative_data=None):
    """Init function for ImageClassifier class.

    Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      hparams: A namedtuple of hyperparameters. This function expects
        .dropout_rate: The fraction of the input units to drop, used in dropout
          layer.
        .do_fine_tuning: If true, the Hub module is trained together with the
          classification layer on top.
      use_augmentation: Use data augmentation for preprocessing.
      representative_data:  Representative dataset for full integer
        quantization. Used when converting the keras model to the TFLite model
        with full integer quantization.
    """
    super(ImageClassifier, self).__init__(model_spec, index_to_label, shuffle,
                                          hparams.do_fine_tuning)
    num_classes = len(index_to_label)
    self._hparams = hparams
    self.preprocess = image_preprocessing.Preprocessor(
        self.model_spec.input_image_shape,
        num_classes,
        self.model_spec.mean_rgb,
        self.model_spec.stddev_rgb,
        use_augmentation=use_augmentation)
    self.history = None  # Training history that returns from `keras_model.fit`.
    self.representative_data = representative_data

  def _get_tflite_input_tensors(self, input_tensors):
    """Gets the input tensors for the TFLite model."""
    return input_tensors

  def create_model(self, hparams=None, with_loss_and_metrics=False):
    """Creates the classifier model for retraining."""
    hparams = self._get_hparams_or_default(hparams)

    module_layer = hub_loader.HubKerasLayerV1V2(
        self.model_spec.uri, trainable=hparams.do_fine_tuning)
    self.model = hub_lib.build_model(module_layer, hparams,
                                     self.model_spec.input_image_shape,
                                     self.num_classes)
    if with_loss_and_metrics:
      # Adds loss and metrics in the keras model.
      self.model.compile(
          loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
          metrics=['accuracy'])
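`create_model` compiles the model with `CategoricalCrossentropy(label_smoothing=0.1)`. The sketch below shows what label smoothing does to a one-hot target, using the usual formula `y * (1 - s) + s / num_classes`. The prediction vector is made up, and the helper functions are illustrative rather than Keras code.

```python
import math


def smooth_labels(one_hot, label_smoothing):
    """Move a little probability mass from the true class to all classes."""
    num_classes = len(one_hot)
    return [y * (1.0 - label_smoothing) + label_smoothing / num_classes
            for y in one_hot]


def cross_entropy(target, predicted):
    """Categorical cross-entropy between two probability vectors."""
    return -sum(t * math.log(p) for t, p in zip(target, predicted))


hard_target = [0.0, 0.0, 1.0, 0.0, 0.0]
soft_target = smooth_labels(hard_target, 0.1)
print([round(t, 2) for t in soft_target])  # [0.02, 0.02, 0.92, 0.02, 0.02]

# A made-up, very confident prediction for class 2.
predicted = [0.01, 0.01, 0.96, 0.01, 0.01]
hard_loss = cross_entropy(hard_target, predicted)
soft_loss = cross_entropy(soft_target, predicted)
print(soft_loss > hard_loss)  # True
```

With smoothing, an overconfident prediction is penalized more than it would be against the hard target, which discourages the classifier from pushing probabilities all the way to 0 or 1.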

Custom classification model that is already retrained by data

class ClassificationModel(custom_model.CustomModel):
  """The abstract base class that represents a Tensorflow classification model."""

  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
                           ExportFormat.SAVED_MODEL, ExportFormat.TFJS)

  def __init__(self, model_spec, index_to_label, shuffle, train_whole_model):
    """Initialize an instance with data, deploy mode and other related parameters.

    Args:
      model_spec: Specification for the model.
      index_to_label: A list that map from index to label class name.
      shuffle: Whether the data should be shuffled.
      train_whole_model: If true, the Hub module is trained together with the
        classification layer on top. Otherwise, only train the top
        classification layer.
    """
    super(ClassificationModel, self).__init__(model_spec, shuffle)
    self.index_to_label = index_to_label
    self.num_classes = len(index_to_label)
    self.train_whole_model = train_whole_model

  def evaluate(self, data, batch_size=32):
    """Evaluates the model.

    Args:
      data: Data to be evaluated.
      batch_size: Number of samples per evaluation step.

    Returns:
      The loss value and accuracy.
    """
    ds = data.gen_dataset(
        batch_size, is_training=False, preprocess=self.preprocess)
    return self.model.evaluate(ds)

  def predict_top_k(self, data, k=1, batch_size=32):
    """Predicts the top-k predictions."""
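The excerpt stops at the `predict_top_k` docstring. As an illustration of what a top-k prediction means, and not the library's actual code, the sketch below ranks a made-up probability vector and maps the best indices back to class names through an `index_to_label` list. `top_k` is a hypothetical helper.

```python
def top_k(probabilities, index_to_label, k=1):
    """Return the k most likely (label, probability) pairs, best first."""
    ranked = sorted(range(len(probabilities)),
                    key=lambda i: probabilities[i],
                    reverse=True)
    return [(index_to_label[i], probabilities[i]) for i in ranked[:k]]


# The five flower classes in sorted order, with invented model outputs.
index_to_label = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
probabilities = [0.05, 0.10, 0.60, 0.05, 0.20]

print(top_k(probabilities, index_to_label, k=2))
# [('roses', 0.6), ('tulips', 0.2)]
```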

class CustomModel(abc.ABC):
  """The abstract base class that represents a Tensorflow classification model."""

  DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE)
  ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.SAVED_MODEL,
                           ExportFormat.TFJS)

  def __init__(self, model_spec, shuffle):
    """Initialize an instance with data, deploy mode and other related parameters.

    Args:
      model_spec: Specification for the model.
      shuffle: Whether the training data should be shuffled.
    """
    self.model_spec = model_spec
    self.shuffle = shuffle
    self.model = None
    # TODO(yuqili): remove this method once preprocess for image classifier is
    # also moved to DataLoader part.
    self.preprocess = None

  @abc.abstractmethod
  def train(self, train_data, validation_data=None, **kwargs):
    return

  def summary(self):
    self.model.summary()

  @abc.abstractmethod
  def evaluate(self, data, **kwargs):
    return

def export_tflite(model,
                  tflite_filepath,
                  quantization_config=None,
                  convert_from_saved_model_tf2=False,
                  preprocess=None,
                  supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,)):
  """Converts the retrained model to tflite format and saves it.

  Args:
    model: model to be converted to tflite.
    tflite_filepath: File path to save tflite model.
    quantization_config: Configuration for post-training quantization.
    convert_from_saved_model_tf2: Convert to TFLite from saved_model in TF 2.x.
    preprocess: A preprocess function to apply on the dataset.
        # TODO(wangtz): Remove when preprocess is split off from CustomModel.
    supported_ops: A list of supported ops in the converted TFLite file.
  """
  if tflite_filepath is None:
    raise ValueError(
        "TFLite filepath couldn't be None when exporting to tflite.")

  if compat.get_tf_behavior() == 1:
    lite = tf.compat.v1.lite
  else:
    lite = tf.lite

  convert_from_saved_model = (
      compat.get_tf_behavior() == 1 or convert_from_saved_model_tf2)
  with _create_temp_dir(convert_from_saved_model) as temp_dir_name:
    if temp_dir_name:
      save_path = os.path.join(temp_dir_name, 'saved_model')
      model.save(save_path, include_optimizer=False, save_format='tf')
      converter = lite.TFLiteConverter.from_saved_model(save_path)
    else:
      converter = lite.TFLiteConverter.from_keras_model(model)

    if quantization_config:
      converter = quantization_config.get_converter_with_quantization(
          converter, preprocess=preprocess)

    converter.target_spec.supported_ops = supported_ops
    tflite_model = converter.convert()

  with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
    f.write(tflite_model)


def get_lite_runner(tflite_filepath, model_spec=None):
  """Gets `LiteRunner` from file path to TFLite model and `model_spec`."""
  # Gets the functions to handle the input & output indexes if exists.
  reorder_input_details_fn = None
  if hasattr(model_spec, 'reorder_input_details'):
    reorder_input_details_fn = model_spec.reorder_input_details

  reorder_output_details_fn = None
  if hasattr(model_spec, 'reorder_output_details'):
    reorder_output_details_fn = model_spec.reorder_output_details

  lite_runner = LiteRunner(tflite_filepath, reorder_input_details_fn,
                           reorder_output_details_fn)
  return lite_runner