TFLite_tutorials
The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow neural-network model to particular input data when deploying this model for on-device ML applications.
Note: the goal here is a model in .tflite format, for deployment on mobile or embedded devices.
The table below lists the task types that TFLite Model Maker currently supports.
Supported Tasks | Task Utility |
---|---|
Image Classification: tutorial, api | Classify images into predefined categories. |
Object Detection: tutorial, api | Detect objects in real time. |
Text Classification: tutorial, api | Classify text into predefined categories. |
BERT Question Answer: tutorial, api | Find the answer in a certain context for a given question with BERT. |
Audio Classification: tutorial, api | Classify audio into predefined categories. |
Recommendation: demo, api | Recommend items based on the context information for on-device scenario. |
If your tasks are not supported, please first use TensorFlow to retrain a TensorFlow model with transfer learning (following guides like images, text, audio) or train it from scratch, and then convert it to a TensorFlow Lite model.
Note: if the model you want to train does not fit any of the task types above, train a TensorFlow model first and then convert it to TFLite.
To use TensorFlow Lite Model Maker, install it first:
pip install tflite-model-maker
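At its core, this is a classification task.
Swap in different models and compare the final accuracy, as well as the TFLite model size, inference speed, memory usage, CPU usage, and so on.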
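Two of the metrics above, file size and elapsed time, can be measured with the standard library alone. The following is a minimal sketch; `file_size_mb` and `timed` are hypothetical helper names, and the 2 MB temporary file only stands in for an exported `.tflite` file. Memory and CPU usage are not covered.

```python
import os
import tempfile
import time


def file_size_mb(path):
    """Return the size of a file in megabytes (1 MB = 1024 * 1024 bytes)."""
    return os.path.getsize(path) / (1024 * 1024)


def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


# Demo on a throwaway 2 MB file standing in for an exported model.
with tempfile.TemporaryDirectory() as tmp:
    fake_model = os.path.join(tmp, "model.tflite")
    with open(fake_model, "wb") as f:
        f.write(b"\0" * (2 * 1024 * 1024))
    size_mb = file_size_mb(fake_model)

total, seconds = timed(sum, range(1000))
print(f"{size_mb:.1f} MB, sum={total}, {seconds:.6f} s")
```

With a trained model, the same wrapper could be applied as `timed(model.evaluate_tflite, './testTFlite/model.tflite', test_data)`.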
The snippet below downloads the dataset.
image_path = tf.keras.utils.get_file(
'flower_photos.tgz',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')
The dataset is structured as follows:
flower_photos
|__ daisy
|______ 100080576_f52e8ee070_n.jpg
|______ 14167534527_781ceb1b7a_n.jpg
|______ …
|__ dandelion
|______ 10043234166_e6dd915111_n.jpg
|______ 1426682852_e62169221f_m.jpg
|______ …
|__ roses
|______ 102501987_3cdb8e5394_n.jpg
|______ 14982802401_a3dfb22afb.jpg
|______ …
|__ sunflowers
|______ 12471791574_bb1be83df4.jpg
|______ 15122112402_cafa41934f.jpg
|______ …
|__ tulips
|______ 13976522214_ccec508fe7.jpg
|______ 14487943607_651e8062a1_m.jpg
|______ …
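Because each subdirectory name is a class label, the layout can be sanity-checked before training by counting the images per class. This is a sketch using only the standard library; `count_images_per_class` is a hypothetical helper, and the demo builds a tiny temporary tree rather than reading the real dataset.

```python
import os
import tempfile


def count_images_per_class(root, extensions=(".jpg", ".jpeg", ".png")):
    """Map each class subdirectory of root to its number of image files."""
    counts = {}
    for entry in sorted(os.listdir(root)):
        class_dir = os.path.join(root, entry)
        if not os.path.isdir(class_dir):
            continue  # skip stray files at the top level
        counts[entry] = sum(
            1 for name in os.listdir(class_dir)
            if name.lower().endswith(extensions)
        )
    return counts


# Demo on a tiny temporary tree with the same one-folder-per-label layout.
with tempfile.TemporaryDirectory() as root:
    for label, n in (("daisy", 2), ("roses", 1)):
        os.makedirs(os.path.join(root, label))
        for i in range(n):
            open(os.path.join(root, label, f"{i}.jpg"), "wb").close()
    demo_counts = count_images_per_class(root)

print(demo_counts)
```

On the real data, the call would be `count_images_per_class(image_path)`.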
Load the dataset and split it:
data = DataLoader.from_folder(image_path)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
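The two `split` calls give an 80/10/10 split into training, validation and test data. The arithmetic can be sketched as below, under two assumptions that are not confirmed by these notes: that `split(fraction)` places `int(size * fraction)` examples in the first part and the rest in the second, and that the dataset holds 3670 images (inferred from the 367 test images reported in these notes).

```python
def split_sizes(total, fraction):
    """Sizes of the two parts when the first part gets int(total * fraction) items."""
    first = int(total * fraction)
    return first, total - first


total_images = 3670  # assumption: 10x the 367 test images reported in these notes

train_n, rest_n = split_sizes(total_images, 0.8)  # 80% for training
val_n, test_n = split_sizes(rest_n, 0.5)          # remaining 20% split in half

print(train_n, val_n, test_n)
```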
assert tf.__version__.startswith('2')
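The assert checks that the TensorFlow version string starts with '2'.
Training results: train_acc = 0.9698, val_acc = 0.9375, test_acc = 0.9210. Overall this matches the usual generalization pattern.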
import os
import time
import numpy as np
import tensorflow as tf
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
import matplotlib.pyplot as plt
assert tf.__version__.startswith('2')
image_path = tf.keras.utils.get_file(
'flower_photos.tgz',
'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
extract=True)
image_path = os.path.join(os.path.dirname(image_path), 'flower_photos')
data = DataLoader.from_folder(image_path)
# data = data.gen_dataset(batch_size=1)
train_data, rest_data = data.split(0.8)
# for batch in data.take(1):
# print(batch)
# break
validation_data, test_data = rest_data.split(0.5)
model = image_classifier.create(train_data, validation_data=validation_data,
model_spec=model_spec.get('efficientnet_lite0'), epochs=20)
loss, accuracy = model.evaluate(test_data)
model.export(export_dir='./testTFlite', export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))
start = time.time()
print(model.evaluate_tflite('./testTFlite/model.tflite', test_data))
end = time.time()
print('elapsed time: ', end - start)
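Judging from the output log above, the model loses very little accuracy after quantization, and the quantized model is 4.0MB (efficientnet_lite0).
As the figure shows, inference runs on a single CPU; test_data contains 367 images and the total time is 273.43 s.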
config = QuantizationConfig.for_float16()
model.export(export_dir='./testTFlite', tflite_filename='model_fp16.tflite', quantization_config=config, export_format=(ExportFormat.TFLITE, ExportFormat.LABEL))
When the model is exported as fp16, its size is 6.8MB (efficientnet_lite0) and inference takes 5.54 s, which is much faster.
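To give a feel for what float16 storage does to individual weights, the sketch below packs a few made-up values as float32 and as float16 with the standard `struct` module. It only illustrates the number format (half the bytes per value, small rounding error); it is not how the converter itself is implemented.

```python
import struct

weights = [0.1, -1.2345678, 3.14159265, 1e-3]  # made-up values, not real model weights

# "f" is IEEE 754 single precision (4 bytes), "e" is half precision (2 bytes).
fp32_bytes = struct.pack(f"<{len(weights)}f", *weights)
fp16_bytes = struct.pack(f"<{len(weights)}e", *weights)

# Read the half-precision values back and measure the rounding error.
restored = struct.unpack(f"<{len(weights)}e", fp16_bytes)
max_rel_error = max(abs(r - w) / abs(w) for r, w in zip(restored, weights))

print(len(fp32_bytes), len(fp16_bytes), max_rel_error)
```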
model = image_classifier.create(train_data, validation_data=validation_data,
model_spec=model_spec.get('mobilenet_v2'), epochs=20)
After switching the model to mobilenet_v2, the exported fp16 model is 4.6MB and inference takes 4.36 s.
inception_v3_spec = image_classifier.ModelSpec(
uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]
model = image_classifier.create(train_data, validation_data=validation_data,
model_spec=inception_v3_spec, epochs=20)
After switching the model to inception_v3, the exported fp16 model is 43.8MB (inception_v3) and inference takes 25.31 s.
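The three fp16 measurements reported in these notes can be put side by side. The sketch below assumes that every timing covers the same 367 test images, which the notes do not state explicitly for each run; `ms_per_image` is a hypothetical helper.

```python
# (model, fp16 model size in MB, seconds to evaluate the test set), as reported in these notes
results = [
    ("efficientnet_lite0", 6.8, 5.54),
    ("mobilenet_v2", 4.6, 4.36),
    ("inception_v3", 43.8, 25.31),
]
NUM_TEST_IMAGES = 367  # assumption: the same test set was used for every run


def ms_per_image(total_seconds, num_images=NUM_TEST_IMAGES):
    """Average time per image in milliseconds."""
    return total_seconds / num_images * 1000.0


# Sort from fastest to slowest.
by_speed = sorted(results, key=lambda r: r[2])
for name, size_mb, seconds in by_speed:
    print(f"{name:20s} {size_mb:6.1f} MB {ms_per_image(seconds):7.2f} ms/image")
```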
Common Dataset used for tasks.
class DataLoader(object):
"""This class provides generic utilities for loading customized domain data that will be used later in model retraining.
For different ML problems or tasks, such as image classification, text
classification etc., a subclass is provided to handle task-specific data
loading requirements.
"""
def __init__(self, dataset, size):
"""Init function for class `DataLoader`.
In most cases, one should use helper functions like `from_folder` to create
an instance of this class.
Args:
dataset: A tf.data.Dataset object that contains a potentially large set of
elements, where each element is a pair of (input_data, target). The
`input_data` means the raw input data, like an image, a text etc., while
the `target` means some ground truth of the raw input data, such as the
classification label of the image etc.
size: The size of the dataset. tf.data.Dataset doesn't support a function
to get the length directly since it's lazy-loaded and may be infinite.
"""
self._dataset = dataset
self._size = size
def gen_dataset(self,
batch_size=1,
is_training=False,
shuffle=False,
input_pipeline_context=None,
preprocess=None,
drop_remainder=False):
"""Generate a shared and batched tf.data.Dataset for training/evaluation.
Image dataloader
class ImageClassifierDataLoader(dataloader.ClassificationDataLoader):
"""DataLoader for image classifier."""
@classmethod
def from_folder(cls, filename, shuffle=True):
"""Image analysis for image classification load images with labels.
Assume the image data of the same label are in the same subdirectory.
Args:
filename: Name of the file.
shuffle: boolean, if shuffle, random shuffle data.
Returns:
ImageDataset containing images and labels and other related info.
"""
@classmethod
def from_tfds(cls, name):
"""Loads data from tensorflow_datasets."""
ImageNet preprocessing
class Preprocessor(object):
"""Preprocessing for image classification."""
def __init__(self,
input_shape,
num_classes,
mean_rgb,
stddev_rgb,
use_augmentation=False):
self.input_shape = input_shape
self.num_classes = num_classes
self.mean_rgb = mean_rgb
self.stddev_rgb = stddev_rgb
self.use_augmentation = use_augmentation
def __call__(self, image, label, is_training=True):
if self.use_augmentation:
return self._preprocess_with_augmentation(image, label, is_training)
return self._preprocess_without_augmentation(image, label)
def _preprocess_with_augmentation(self, image, label, is_training):
"""Image preprocessing method with data augmentation."""
image_size = self.input_shape[0]
if is_training:
image = preprocess_for_train(image, image_size)
else:
image = preprocess_for_eval(image, image_size)
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
label = tf.one_hot(label, depth=self.num_classes)
return image, label
# TODO(yuqili): Changes to preprocess to support batch input.
def _preprocess_without_augmentation(self, image, label):
"""Image preprocessing method without data augmentation."""
image = tf.cast(image, tf.float32)
image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
image = tf.compat.v1.image.resize(image, self.input_shape)
label = tf.one_hot(label, depth=self.num_classes)
return image, label
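The normalization and one-hot steps in `Preprocessor` can be illustrated on a single pixel in plain Python. This is only a sketch: the `MEAN_RGB` and `STDDEV_RGB` values are assumptions chosen for illustration (the real values come from the model spec), and `normalize_pixel` and `one_hot` are hypothetical helpers that mirror `(image - mean_rgb) / stddev_rgb` and `tf.one_hot`.

```python
def normalize_pixel(rgb, mean_rgb, stddev_rgb):
    """Per-channel (value - mean) / stddev, as applied to every pixel."""
    return [(c - m) / s for c, m, s in zip(rgb, mean_rgb, stddev_rgb)]


def one_hot(label, num_classes):
    """Vector of zeros with a single 1.0 at the label index."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


MEAN_RGB = [127.0, 127.0, 127.0]    # assumption, for illustration only
STDDEV_RGB = [128.0, 128.0, 128.0]  # assumption, for illustration only

pixel = normalize_pixel([255, 127, 0], MEAN_RGB, STDDEV_RGB)
target = one_hot(2, 5)  # class index 2 out of the 5 flower classes

print(pixel)
print(target)
```

With these values, the 0..255 pixel range maps to roughly -1..1.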
class ImageClassifier(classification_model.ClassificationModel):
"""ImageClassifier class for inference and exporting to tflite."""
def __init__(self,
model_spec,
index_to_label,
shuffle=True,
hparams=hub_lib.get_default_hparams(),
use_augmentation=False,
representative_data=None):
"""Init function for ImageClassifier class.
Args:
model_spec: Specification for the model.
index_to_label: A list that maps from index to label class name.
shuffle: Whether the data should be shuffled.
hparams: A namedtuple of hyperparameters. This function expects
.dropout_rate: The fraction of the input units to drop, used in dropout
layer.
.do_fine_tuning: If true, the Hub module is trained together with the
classification layer on top.
use_augmentation: Use data augmentation for preprocessing.
representative_data: Representative dataset for full integer
quantization. Used when converting the keras model to the TFLite model
with full integer quantization.
"""
super(ImageClassifier, self).__init__(model_spec, index_to_label, shuffle,
hparams.do_fine_tuning)
num_classes = len(index_to_label)
self._hparams = hparams
self.preprocess = image_preprocessing.Preprocessor(
self.model_spec.input_image_shape,
num_classes,
self.model_spec.mean_rgb,
self.model_spec.stddev_rgb,
use_augmentation=use_augmentation)
self.history = None # Training history that returns from `keras_model.fit`.
self.representative_data = representative_data
def _get_tflite_input_tensors(self, input_tensors):
"""Gets the input tensors for the TFLite model."""
return input_tensors
def create_model(self, hparams=None, with_loss_and_metrics=False):
"""Creates the classifier model for retraining."""
hparams = self._get_hparams_or_default(hparams)
module_layer = hub_loader.HubKerasLayerV1V2(
self.model_spec.uri, trainable=hparams.do_fine_tuning)
self.model = hub_lib.build_model(module_layer, hparams,
self.model_spec.input_image_shape,
self.num_classes)
if with_loss_and_metrics:
# Adds loss and metrics in the keras model.
self.model.compile(
loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
metrics=['accuracy'])
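The `label_smoothing=0.1` argument softens the one-hot targets before the cross-entropy is computed. A minimal sketch of the smoothing formula, `y * (1 - eps) + eps / num_classes`, in plain Python (`smooth_labels` is a hypothetical helper, not part of the library):

```python
def smooth_labels(one_hot, label_smoothing=0.1):
    """Apply label smoothing: y * (1 - eps) + eps / num_classes."""
    num_classes = len(one_hot)
    return [
        y * (1.0 - label_smoothing) + label_smoothing / num_classes
        for y in one_hot
    ]


# One-hot target for class index 2 out of 5 classes.
smoothed = smooth_labels([0.0, 0.0, 1.0, 0.0, 0.0])
print(smoothed)
```

With five classes, the true class gets a target of about 0.92 and every other class about 0.02, so the targets still sum to 1.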
Custom classification model that is already retrained by data
class ClassificationModel(custom_model.CustomModel):
"""The abstract base class that represents a Tensorflow classification model."""
DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL)
ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.LABEL,
ExportFormat.SAVED_MODEL, ExportFormat.TFJS)
def __init__(self, model_spec, index_to_label, shuffle, train_whole_model):
"""Initialize an instance with data, deploy mode and other related parameters.
Args:
model_spec: Specification for the model.
index_to_label: A list that maps from index to label class name.
shuffle: Whether the data should be shuffled.
train_whole_model: If true, the Hub module is trained together with the
classification layer on top. Otherwise, only train the top
classification layer.
"""
super(ClassificationModel, self).__init__(model_spec, shuffle)
self.index_to_label = index_to_label
self.num_classes = len(index_to_label)
self.train_whole_model = train_whole_model
def evaluate(self, data, batch_size=32):
"""Evaluates the model.
Args:
data: Data to be evaluated.
batch_size: Number of samples per evaluation step.
Returns:
The loss value and accuracy.
"""
ds = data.gen_dataset(
batch_size, is_training=False, preprocess=self.preprocess)
return self.model.evaluate(ds)
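`evaluate` batches the data with `batch_size=32` by default, and `gen_dataset` also accepts `drop_remainder`. How the two interact can be sketched with simple arithmetic; `num_batches` is a hypothetical helper, and the 367 examples are the test-set size reported in these notes.

```python
import math


def num_batches(num_examples, batch_size, drop_remainder=False):
    """Number of batches produced when batching num_examples items."""
    if drop_remainder:
        # The final, incomplete batch is discarded.
        return num_examples // batch_size
    # The final batch may be smaller than batch_size.
    return math.ceil(num_examples / batch_size)


steps_keep = num_batches(367, 32)
steps_drop = num_batches(367, 32, drop_remainder=True)
print(steps_keep, steps_drop)
```

Dropping the remainder would leave 367 - 11 * 32 = 15 examples out, which matters for evaluation, where every example should normally be scored.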
def predict_top_k(self, data, k=1, batch_size=32):
"""Predicts the top-k predictions.
class CustomModel(abc.ABC):
"""The abstract base class that represents a Tensorflow classification model."""
DEFAULT_EXPORT_FORMAT = (ExportFormat.TFLITE)
ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.SAVED_MODEL,
ExportFormat.TFJS)
def __init__(self, model_spec, shuffle):
"""Initialize an instance with data, deploy mode and other related parameters.
Args:
model_spec: Specification for the model.
shuffle: Whether the training data should be shuffled.
"""
self.model_spec = model_spec
self.shuffle = shuffle
self.model = None
# TODO(yuqili): remove this method once preprocess for image classifier is
# also moved to DataLoader part.
self.preprocess = None
@abc.abstractmethod
def train(self, train_data, validation_data=None, **kwargs):
return
def summary(self):
self.model.summary()
@abc.abstractmethod
def evaluate(self, data, **kwargs):
return
def export_tflite(model,
tflite_filepath,
quantization_config=None,
convert_from_saved_model_tf2=False,
preprocess=None,
supported_ops=(tf.lite.OpsSet.TFLITE_BUILTINS,)):
"""Converts the retrained model to tflite format and saves it.
Args:
model: model to be converted to tflite.
tflite_filepath: File path to save tflite model.
quantization_config: Configuration for post-training quantization.
convert_from_saved_model_tf2: Convert to TFLite from saved_model in TF 2.x.
preprocess: A preprocess function to apply on the dataset.
# TODO(wangtz): Remove when preprocess is split off from CustomModel.
supported_ops: A list of supported ops in the converted TFLite file.
"""
if tflite_filepath is None:
raise ValueError(
"TFLite filepath couldn't be None when exporting to tflite.")
if compat.get_tf_behavior() == 1:
lite = tf.compat.v1.lite
else:
lite = tf.lite
convert_from_saved_model = (
compat.get_tf_behavior() == 1 or convert_from_saved_model_tf2)
with _create_temp_dir(convert_from_saved_model) as temp_dir_name:
if temp_dir_name:
save_path = os.path.join(temp_dir_name, 'saved_model')
model.save(save_path, include_optimizer=False, save_format='tf')
converter = lite.TFLiteConverter.from_saved_model(save_path)
else:
converter = lite.TFLiteConverter.from_keras_model(model)
if quantization_config:
converter = quantization_config.get_converter_with_quantization(
converter, preprocess=preprocess)
converter.target_spec.supported_ops = supported_ops
tflite_model = converter.convert()
with tf.io.gfile.GFile(tflite_filepath, 'wb') as f:
f.write(tflite_model)
def get_lite_runner(tflite_filepath, model_spec=None):
"""Gets `LiteRunner` from file path to TFLite model and `model_spec`."""
# Gets the functions to handle the input & output indexes if exists.
reorder_input_details_fn = None
if hasattr(model_spec, 'reorder_input_details'):
reorder_input_details_fn = model_spec.reorder_input_details
reorder_output_details_fn = None
if hasattr(model_spec, 'reorder_output_details'):
reorder_output_details_fn = model_spec.reorder_output_details
lite_runner = LiteRunner(tflite_filepath, reorder_input_details_fn,
reorder_output_details_fn)
return lite_runner
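`get_lite_runner` treats `reorder_input_details` and `reorder_output_details` as optional hooks: they are used only when the model spec defines them. The pattern can be sketched in isolation as follows. `PlainSpec`, `ReorderingSpec` and `get_optional_hook` are hypothetical names, and `getattr` with a default is used here as an equivalent of the `hasattr` check.

```python
class PlainSpec:
    """A spec without any reorder hooks."""


class ReorderingSpec:
    """A spec that defines an optional hook."""

    def reorder_input_details(self, details):
        # Illustrative behavior only: order the tensors by name.
        return sorted(details, key=lambda d: d["name"])


def get_optional_hook(spec, name):
    """Return spec.<name> if it exists, else None (also works when spec is None)."""
    return getattr(spec, name, None)


details = [{"name": "b"}, {"name": "a"}]

plain_hook = get_optional_hook(PlainSpec(), "reorder_input_details")
none_hook = get_optional_hook(None, "reorder_input_details")
hook = get_optional_hook(ReorderingSpec(), "reorder_input_details")

# Apply the hook only when it is present.
reordered = hook(details) if hook is not None else details
print(plain_hook, none_hook, reordered)
```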