如何在Android应用中使用TensorFlow Mobile

本文详细介绍了如何在Android Studio项目中使用TensorFlow Mobile来部署机器学习模型。首先,你需要一个训练好的TensorFlow模型,然后冻结模型以便在移动设备上使用。接着,设置Android Studio项目,将TensorFlow Mobile库添加为依赖,并将冻结模型放入项目资产。通过TensorFlow接口初始化模型,并使用模型进行预测。最后,文章鼓励读者尝试更复杂的预训练模型,但要注意这可能增加APK大小。
摘要由CSDN通过智能技术生成

使用TensorFlow (当今最流行的机器学习框架之一),您可以轻松创建和训练深度模型(通常也称为深度前馈神经网络),该模型可以解决各种复杂问题,例如图像分类,对象检测和自然语言理解。 TensorFlow Mobile是一个旨在帮助您在移动应用程序中利用这些模型的库。

在本教程中,我将向您展示如何在Android Studio项目中使用TensorFlow Mobile。

先决条件

要遵循本教程,您需要:

  • Android Studio 3.0或更高版本
  • TensorFlow 1.5.0或更高版本
  • 运行API级别21或更高级别的Android设备
  • 对TensorFlow框架有基本了解

1.建立模型

在开始使用TensorFlow Mobile之前,我们需要一个训练有素的TensorFlow模型。 现在创建一个。

我们的模型将是非常基本的。 它的行为就像XOR门,接收两个输入(两个输入都可以为零或一个),并产生一个输出,如果两个输入相同,则输出为零。 此外,由于它将成为一个深层模型,它将具有两个隐藏层,一个具有四个神经元,另一个具有三个神经元。 您可以自由更改隐藏层的数量及其包含的神经元的数量。

为了使本教程简短,我们将使用TFLearn (一种流行的TensorFlow包装框架,提供更直观,简洁的API),而不是直接使用低级 TensorFlow API。 如果尚未安装,请使用以下命令将其安装在TensorFlow虚拟环境中:

pip install tflearn

要开始创建模型,请创建一个名为create_model.py的Python脚本,最好在一个空目录中,然后使用您喜欢的文本编辑器将其打开。

在文件内部,我们要做的第一件事是导入TFLearn API。

import tflearn

接下来,我们必须创建训练数据。 对于我们的简单模型,将只有四个可能的输入和输出,它们类似于XOR门的真值表的内容。

X = [
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1]
]

Y = [
    [0],  # Desired output for inputs 0, 0
    [1],  # Desired output for inputs 0, 1
    [1],  # Desired output for inputs 1, 0
    [0]   # Desired output for inputs 1, 1
]

通常最好使用从均匀分布中选取的随机值,同时将初始权重分配给隐藏层中的所有神经元。 要生成值,请使用uniform()方法。

weights = tflearn.initializations.uniform(minval = -1, maxval = 1)

在这一点上,我们可以开始创建神经网络的各层。 要创建输入层,我们必须使用input_data()方法,该方法允许我们指定网络可以接受的输入数量。 输入层准备就绪后,我们可以多次调用fully_connected()方法以向网络添加更多层。

# Input layer
net = tflearn.input_data(
        shape = [None, 2],
        name = 'my_input'
)

# Hidden layers
net = tflearn.fully_connected(net, 4,
        activation = 'sigmoid',
        weights_init = weights
)
net = tflearn.fully_connected(net, 3,
        activation = 'sigmoid',
        weights_init = weights
)

# Output layer
net = tflearn.fully_connected(net, 1,
        activation = 'sigmoid', 
        weights_init = weights,
        name = 'my_output'
)

注意,在上面的代码中,我们为输入和输出层赋予了有意义的名称。 这样做很重要,因为在使用Android应用中的网络时,我们将需要它们。 另请注意,隐藏层和输出层正在使用sigmoid激活功能。 您可以随意尝试其他激活功能,例如softmaxtanhrelu

作为网络的最后一层,我们必须使用regression()函数创建一个回归层,该函数需要一些超参数作为其参数,例如网络的学习率以及应该使用的优化器和损失函数。 下面的代码向您展示如何使用随机梯度下降法(简称SGD)作为优化程序函数,并使用均方作为损失函数:

net = tflearn.regression(net,
        learning_rate = 2,
        optimizer = 'sgd',
        loss = 'mean_square'
)

接下来,为了让TFLearn框架知道我们的网络模型实际上是一个深度神经网络模型,我们必须调用DNN()函数。

model = tflearn.DNN(net)

模型现在准备就绪。 我们现在要做的就是使用我们之前创建的训练数据来训练它。 因此,调用模型的fit()方法,并与训练数据一起指定要运行的训练时期的数量。 由于训练数据非常小,因此我们的模型需要数千个纪元才能达到合理的准确性。

model.fit(X, Y, 5000)

训练完成后,我们可以调用模型的predict()方法来检查其是否正在生成所需的输出。 以下代码显示了如何检查所有有效输入的输出:

print("1 XOR 0 = %f" % model.predict([[1,0]]).item(0))
print("1 XOR 1 = %f" % model.predict([[1,1]]).item(0))
print("0 XOR 1 = %f" % model.predict([[0,1]]).item(0))
print("0 XOR 0 = %f" % model.predict([[0,0]]).item(0))

如果现在运行Python脚本,则应该看到如下所示的输出:

Predictions after training

请注意,输出永远不会完全是0或1。相反,它们是接近零或接近一的浮点数。 因此,在使用输出时,您可能要使用Python的round()函数。

除非我们在训练后明确保存模型,否则脚本结束后我们将立即丢失它。 幸运的是,使用TFLearn,只需调用save()方法即可保存模型。 但是,要能够在TensorFlow Mobile中使用已保存的模型,在保存之前,我们必须确保删除与tf.GraphKeys.TRAIN_OPS集合相关的所有训练相关操作。 以下代码显示了如何执行此操作:

# Remove train ops
with net.graph.as_default():
    del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]

# Save the model
model.save('xor.tflearn')

如果再次运行该脚本,则会看到它生成了一个检查点文件,元数据文件,索引文件和数据文件,将它们一起使用可以快速重新创建我们训练有素的模型。

2.冻结模型

除了保存模型外,我们还必须先冻结模型,然后才能将其用于TensorFlow Mobile。 正如您可能已经猜到的,冻结模型的过程涉及将其所有变量转换为常量。 此外,冻结模型必须是符合Google协议缓冲区序列化格式的单个二进制文件。

创建一个名为freeze_model.py的新Python脚本,然后使用文本编辑器将其打开。 我们将编写所有代码以将模型冻结在此文件中。

由于TFLearn没有冻结模型的任何功能,因此我们现在必须直接使用TensorFlow API。 通过将以下行添加到文件中来导入它们:

import tensorflow as tf

在整个脚本中,我们将使用单个TensorFlow会话。 要创建会话,请使用Session类的构造函数。

with tf.Session() as session:
    # Rest of the code goes here

此时,我们必须通过调用import_meta_graph()函数并将模型的元数据文件的名称传递给它来创建一个Saver对象。 除了返回Saver对象之外, import_meta_graph()函数还将自动将模型的图形定义添加到会话的图形定义中。

创建保护程序后,我们可以通过调用restore()方法初始化图形定义中存在的所有变量,该方法需要包含模型最新检查点文件的目录路径。

my_saver = tf.train.import_meta_graph('xor.tflearn.meta')
my_saver.restore(session, tf.train.latest_checkpoint('.'))

在这一点上,我们可以调用convert_variables_to_constants()函数来创建一个冻结的图定义,其中将模型的所有变量替换为常量。 作为其输入,该函数期望当前会话,当前会话的图形定义以及包含模型输出层名称的列表。

frozen_graph = tf.graph_util.convert_variables_to_constants(
    session,
    session.graph_def,
    ['my_output/Sigmoid']
)

调用冻结图定义的SerializeToString()方法将为我们提供模型的二进制protobuf表示形式。 通过使用Python的基本文件I / O工具,建议您将其保存为名为Frozen_model.pb的文件。

with open('frozen_model.pb', 'wb') as f:
    f.write(frozen_graph.SerializeToString())

您现在可以运行脚本以生成冻结的模型。

现在,我们拥有开始使用TensorFlow Mobile所需的一切。

3. Android Studio项目设置

TensorFlow Mobile库在JCenter上可用,因此我们可以将其作为implementation依赖项直接添加到app模块的build.gradle文件中。

implementation 'org.tensorflow:tensorflow-android:1.7.0'

要将冻结的模型添加到项目中,请将Frozen_model.pb文件放置在项目的资产文件夹中。

4.初始化TensorFlow接口

TensorFlow Mobile提供了一个简单的界面,可用于与冻结模型进行交互。 要创建接口,请使用TensorFlowInferenceInterface类的构造函数,该构造函数需要一个AssetManager实例和冻结模型的文件名。

thread {
    val tfInterface = TensorFlowInferenceInterface(assets,
                                        "frozen_model.pb")
	
    // More code here
}

在上面的代码中,您可以看到我们正在生成一个新线程。 建议这样做(尽管并非总是必要的),以确保应用程序的UI保持响应状态。

为确保TensorFlow Mobile能够正确读取我们的模型文件,现在让我们尝试打印模型图中存在的所有操作的名称。 要获得对图形的引用,我们可以使用接口的graph()方法,并获取所有操作,即图形的operations()方法。 以下代码向您展示了如何:

val graph = tfInterface.graph()
graph.operations().forEach {
    println(it.name())
}

如果您现在运行该应用程序,则应该能够看到在Android Studio的Logcat窗口中打印的十几个操作名称。 在所有这些名称中,如果冻结模型时没有任何错误,则可以找到输入和输出层的名称: my_input / Xmy_output / Sigmoid

Logcat window showing list of operations

5.使用模型

为了对模型进行预测,我们必须将数据放入其输入层,并从其输出层检索数据。 要将数据放入输入层,请使用接口的feed()方法,该方法需要该层的名称,包含输入的数组以及该数组的尺寸。 以下代码显示了如何将数字01发送到输入层:

tfInterface.feed("my_input/X",
            floatArrayOf(0f, 1f), 1, 2)

将数据加载到输入层之后,我们必须使用run()方法运行推断操作,该方法需要输出层的名称。 一旦操作完成,输出层将包含模型的预测。 要将预测加载到Kotlin数组中,我们可以使用fetch()方法。 以下代码显示了如何执行此操作:

tfInterface.run(arrayOf("my_output/Sigmoid"))

val output = floatArrayOf(-1f)
tfInterface.fetch("my_output/Sigmoid", output)

当然,如何使用预测取决于您。 现在,我建议您只打印它。

println("Output is ${output[0]}")

您现在可以运行应用程序,以查看模型的预测是否正确。

Logcat window displaying the prediction

随时更改输入到输入层的数字,以确认模型的预测始终正确。

结论

您现在知道如何创建一个简单的TensorFlow模型并将其与Android应用程序中的TensorFlow Mobile一起使用。 但是,您不必总是将自己局限于自己的模型。 有了您今天所学的技能,使用TensorFlow 模型动物园中可用的较大模型(例如MobileNet和Inception)就不会有问题。 但是请注意,此类模型会导致APK变大,这可能会给使用低端设备的用户带来问题。

要了解有关TensorFlow Mobile的更多信息,请参考官方文档

翻译自: https://code.tutsplus.com/tutorials/how-to-use-tensorflow-mobile-in-android-apps--cms-30957

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值