1、什么是Tensorflow Lite?
TensorFlow Lite 是一组工具,可帮助开发者在移动设备、嵌入式设备和 IoT 设备上运行 TensorFlow 模型。它支持设备端机器学习推断,延迟较低,并且二进制文件很小。
2、Tensorflow Lite的开发流程
Tensorflow Lite包含两个组件,分别是:
Tensorflow Lite转换器
转换器的目的是将Tensorflow模型转换成可供Tensorflow Lite解释器可用的模型格式,并可引入优化以减小二进制文件的大小和提高性能。
Tensorflow Lite解释器
它可在手机、嵌入式 Linux 设备和微控制器等很多不同类型的硬件上运行经过专门优化的模型。
使用 TensorFlow Lite 的工作流包括如下步骤:
2.1 选择模型
2.2 转换模型
使用Tensorflow Lite转换器将模型转换微Tensorflow Lite的格式。
2.3 部署模型
使用Tensorflow Lite解释器在设备端运行模型
2.4 优化模型
使用模型优化工具包缩减模型大小并提高效率,同时最大限度降低对准确率的影响。
3、代码示例讲解
3.1选择模型
本例子用tensorflow自带的mobilenet网络模型结构。
import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
model = MobileNetV2(weights='imagenet')
model_save_path = 'model//tf_keras_mobilenet'
model.save(filepath=model_save_path)
print(model.summary())
2、转换模型
使用Tensorflow Lite转换器将模型转换成Tensorflow Lite格式的模型
import tensorflow as tf
# save tf_lite
# Converting a tf.Keras model to a TensorFlow Lite model.
model = 'model//tf_keras_mobilenet'
converter = tf.lite.TFLiteConverter.from_saved_model(model)
# # float16
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.compat.v1.lite.constants.FLOAT16]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
converter.allow_custom_ops = True
tflite_model = converter.convert()
open('model//tf_keras_mobilenet//mobilenet_model.tflite', 'wb').write(tflite_model)
3、部署模型
使用Tensorflow Lite解释器来推理模型
import tensorflow as tf
import numpy as np
import cv2
import time
tf_lite_path = 'model//tf_keras_mobilenet//mobilenet_model.tflite'
interpreter = tf.lite.Interpreter(model_path=tf_lite_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
print('input_details:',input_details)
output_details = interpreter.get_output_details()
print('output_details:',output_details)
start = time.time()
img = cv2.imread('girl.jpg')
img = img.astype(np.float32)
print(img.dtype)
img = cv2.resize(img,dsize=(224,224))
input_data = tf.expand_dims(img,axis=0)
print('input_data.shape:',input_data.shape)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
print('time:',time.time()-start) # float16 time: 0.12467169761657715
print(output_data)
4、优化模型还没吃透,再研究研究