TensorFlow Lite 开发手册(3)——模型转换

(一)模型转换简介

(1)工作流程

基本工作流程如下:
在这里插入图片描述
TensorFlow Lite 提供以下三种模型转换方法:

  • tf.lite.TFLiteConverter.from_keras_model(),转换实例化的Keras模型
  • tf.lite.TFLiteConverter.from_saved_model(),转换pb文件
  • tf.lite.TFLiteConverter.from_concrete_functions(),转换具体的函数

(2)转换示例模型

import numpy as np
# 转换模型。
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open('/home/ai/converted_model.tflite', 'wb').write(tflite_model)

最终得到转化后的模型——converted_model.tflite

(二)模型调用

(1) Python 接口

# 加载 TFLite 模型并分配张量(tensor)。
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# 获取输入和输出张量。
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# 使用随机数据作为输入测试 TensorFlow Lite 模型。
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# 函数 `get_tensor()` 会返回一份张量的拷贝。
# 使用 `tensor()` 获取指向张量的指针。
tflite_results = interpreter.get_tensor(output_details[0]['index'])

# 使用随机数据作为输入测试 TensorFlow 模型。
tf_results = model(tf.constant(input_data))

print("tflite result:", tflite_results)
print("tf result:", tf_results)

输出结果如下:

tflite result: [[2.80235346e-09 9.99904633e-01 1.81704600e-05 2.76264545e-09
  8.59975898e-06 3.23225287e-08 6.01521315e-05 1.01964176e-07
  8.37052448e-06 3.09832138e-09]]
tf result: tf.Tensor(
[[2.8023428e-09 9.9990463e-01 1.8170409e-05 2.7626350e-09 8.5997262e-06
  3.2322404e-08 6.0152019e-05 1.0196379e-07 8.3704927e-06 3.0983038e-09]], shape=(1, 10), dtype=float32)

(2) C++接口

// Load the model
std::unique_ptr<tflite::FlatBufferModel> model =
    tflite::FlatBufferModel::BuildFromFile(filename);

// Build the interpreter
tflite::ops::builtin::BuiltinOpResolver resolver;
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);

// Resize input tensors, if desired.
interpreter->AllocateTensors();

float* input = interpreter->typed_input_tensor<float>(0);
// Fill `input`.

interpreter->Invoke();

float* output = interpreter->typed_output_tensor<float>(0);
发布了80 篇原创文章 · 获赞 74 · 访问量 31万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 黑客帝国 设计师: 上身试试

分享到微信朋友圈

×

扫一扫,手机浏览