本教程介绍如何使用 tf.Keras 时序 API 从头开始训练模型,将 tf.Keras 模型转换为 tflite 格式,并在 Android 上运行该模型。我将以 MNIST 数据为例介绍图像分类,并分享一些你可能会面临的常见问题。本教程着重于端到端的体验,我不会深入探讨各种 tf.Keras API 或 Android 开发。
下载我的示例代码并执行以下操作:
在 colab 中运行:使用 tf.keras 的训练模型,并将 keras 模型转换为 tflite(链接到 Colab notebook)。
在 Android Studio 中运行:DigitRecognizer(链接到 Android 应用程序)。
1. 训练自定义分类器
加载数据
我们将使用作为 tf.keras 框架一部分的 mnst 数据。
( x_train, y_train ) , ( x_test, y_test ) = keras.datasets.mnist.load_data ( )
预处理数据
接下来,我们将输入图像从 28x28 变为 28x28x1 的形状,将其标准化,并对标签进行 one-hot 编码。
定义模型体系结构
然后我们将用 cnn 定义网络架构。
def create_model ( ) :
# Define the model architecture
model = keras.models.Sequential ( [
# Must define the input shape in the first layer of the neural network
keras.layers.Conv2D ( filters=32,