背景
TensorFlow Lite 转换器可根据输入的 TensorFlow 模型生成 TensorFlow Lite 模型(一种优化的 FlatBuffer 格式,以 .tflite
为文件扩展名). 作用是进一步缩短模型延迟时间和减小模型大小,同时最大限度降低准确率损失和添加元数据,从而在设备上部署模型时可以更轻松地创建平台专用封装容器代码。
环境
tensorflow=2.4.1
实践例子
把Tensorflow的模型转换成tflite
import tensorflow as tf
def convert_to_tflite(model):
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tfmodel = converter.convert()
file = open('yourmodel.tflite', 'wb')
file.write(tfmodel)
file.close()
运行Tflite模型
import tensorflow as tf
def run_reference_by_tflite(input):
interpreter = tf.lite.Interpreter(model_path="yourmodel.tflite")
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# input details
print(input_details, len(input_details))
# output details
print(output_details)
# input_details[0]['index'] = the index which accepts the input
interpreter.set_tensor(input_details[0]['index'], input)
# run the inference
interpreter.invoke()
# output_details[0]['index'] = the index which provides the input
output_data = interpreter.get_tensor(output_details[0]['index'])
print('interpreter: ', output_data)
转tfhub中的模型
!pip install -q tensorflow-text
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
hub.load("https://tfhub.dev/google/universal-sentence-encoder-multilingual/3") # Caches the model in /tmp/tfhub_modules
converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/tfhub_modules/26c892ffbc8d7b032f5a95f316e2841ed4f1608c")
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
tf.lite.OpsSet.SELECT_TF_OPS, # enable TensorFlow ops.
]
tflite_file = "model.tflite"
with open(tflite_file, 'wb') as f:
f.write(converter.convert())
interpreter = tf.lite.Interpreter(tflite_file)
interpreter.get_signature_list() # {'serving_default': {'inputs': ['inputs'], 'outputs': ['outputs']}}
不同版本的tensorflow或不同的格式的模型对应的转换方法
-
使用 tf.lite.TFLiteConverter 转换 TensorFlow 2.x 模型。TensorFlow 2.x 模型是使用 SavedModel 格式存储的,并通过高阶
tf.keras.*
API(Keras 模型)或低阶tf.*
API(用于生成具体函数)生成。因此,您有以下三个选项(示例包含在接下来的几节中): -
使用 tf.compat.v1.lite.TFLiteConverter 转换 TensorFlow 1.x 模型(示例位于 GitHub 上):
- tf.compat.v1.lite.TFLiteConverter.from_saved_model():转换 SavedModel。
- tf.compat.v1.lite.TFLiteConverter.from_keras_model_file():转换 Keras 模型。
- tf.compat.v1.lite.TFLiteConverter.from_session():从会话转换 GraphDef。
- tf.compat.v1.lite.TFLiteConverter.from_frozen_graph():从文件转换 Frozen GraphDef。如果您有检查点,请先将其转换为 Frozen GraphDef 文件,然后使用此 API(如此处所示)。
相关错误以及解决方法
1. ValueError: Cannot set tensor: Got value of type NOTYPE but expected type FLOAT32 for input 0,
或
ValueError: Cannot set tensor: Got value of type INT32 but expected type FLOAT32 for input 0,
或
ValueError: Cannot set tensor: Got value of type UINT8 but expected type FLOAT32 for input 0, name: input_1
解决方法:
under_exp = np.array(under_exp, dtype=np.float32)
参考资料
python - How to convert keras(h5) file to a tflite file? - Stack Overflowhttps://www.tensorflow.org/lite/convert