TensorFlow Lite 是用于移动设备和嵌入式设备的轻量级解决方案。TensorFlow Lite 支持 Android、iOS 甚至树莓派等多种平台。
我们知道大多数的 AI 是在云端运算的,但是在移动端使用 AI 具有无网络延迟、响应更加及时、数据隐私等特性。
对于离线的场合,云端的 AI 就无法使用了,而此时可以在移动设备中使用 TensorFlow Lite。
二. tflite 格式
TensorFlow 生成的模型是无法直接给移动端使用的,需要离线转换成.tflite文件格式。
tflite 存储格式是 flatbuffers。
FlatBuffers 是由Google开源的一个免费软件库,用于实现序列化格式。它类似于Protocol Buffers、Thrift、Apache Avro。
因此,如果要给移动端使用的话,必须把 TensorFlow 训练好的 protobuf 模型文件转换成 FlatBuffers 格式。官方提供了 toco 来实现模型格式的转换。
三. 常用的 Java API
TensorFlow Lite 提供了 C ++ 和 Java 两种类型的 API。无论哪种 API 都需要加载模型和运行模型。
而 TensorFlow Lite 的 Java API 使用了 Interpreter 类(解释器)来完成加载模型和运行模型的任务。后面的例子会看到如何使用 Interpreter。
四. TensorFlow Lite + mnist 数据集实现识别手写数字
mnist 是手写数字图片数据集,包含60000张训练样本和10000张测试样本。 测试集也是同样比例的手写数字数据。每张图片有28x28个像素点构成,每个像素点用一个灰度值表示,这里是将28x28的像素展开为一个一维的行向量(每行784个值)。
mnist 数据集获取地址:http://yann.lecun.com/exdb/mnist/
下面的 demo 中已经包含了 mnist.tflite 模型文件。(如果没有的话,需要自己训练保存成pb文件,再转换成tflite 格式)
对于一个识别类,首先需要初始化 TensorFlow Lite 解释器,以及输入、输出。
// The tensorflow lite file
private lateinit var tflite: Interpreter
// Input byte buffer
private lateinit var inputBuffer: ByteBuffer
// Output array [batch_size, 10]
private lateinit var mnistOutput: Array<FloatArray>
init {
try {
tflite = Interpreter(loadModelFile(activity))
inputBuffer = ByteBuffer.allocateDirect(
BYTE_SIZE_OF_FLOAT * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE)
inputBuffer.order(ByteOrder.nativeOrder())
mnistOutput = Array(DIM_BATCH_SIZE)