最近想把在tensorflow上训练的模型移植到安卓上运行,看了一下网上的例子,感觉都很复杂,对于我这种不太会安卓代码的人很不友好,所以决定自己摸索,再看了tensorflow官方的demo后,决定写出了下面这个简易版demo,带你快速了解如何将pb模型移植到安卓上运行。
我的环境:
- windows10
- python3.7
- tensorflow-gpu1.14
- pycharm
- android studio
这个demo并不是针对移植神经网络模型,而是针对pb文件的调用,所以我只写了一段简单的代码来生成pb文件。下面先贴出我的pb文件生成代码。
# -*- coding:utf-8 -*-
# 这是一个基于tensorflow的简单计算,并且保存模型为pd文件
import tensorflow as tf
sess = tf.Session()
matrix1 = tf.placeholder(tf.float32, [2, ], name='input1')
matrix2 = tf.placeholder(tf.float32, [2, ], name='input2')
mat_add = tf.add(matrix1, matrix2, name='output1')
mat_sub = tf.subtract(matrix1, matrix2, name='output2')
res1 = sess.run(mat_add, feed_dict={matrix1: [4, 6], matrix2: [3, 1]})
res2 = sess.run(mat_sub, feed_dict={matrix1: [4, 6], matrix2: [3, 1]})
print("res1=", res1)
print("res2=", res2)
# 保存二进制模型
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def,
output_node_names=['output1', 'output2']) # output_node_names指定要保存哪些输出tensor
with tf.gfile.FastGFile('test.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())
sess.close()
这个模型非常简单,两个输入tensor两个输出tensor,一个加法运算OP,一个减法运算OP。这么简单的模型,应该算是非常亲民了吧。
然后我们使用android studio的模板生成一个新的工程,就是那个能够直接打印出hollow world的那个模板。
首先添加一个名为assets资源文件夹到app/src/main/里面,然后把test.pb文件放入assets文件夹中,如下图所示:
之后修改在app/src文件夹下的build.gradle,在“dependencies”里添加(印象中需要翻墙,想不翻墙的话就自己编译库吧):
implementation 'org.tensorflow:tensorflow-android:1.13.1'
用来加入tensorflow的aar库。如下图所示:(可以版本号改为+号,来达到不指定版本的目的)
最后在MainActivity.java中进行修改,调用pb模型并运行和打印结果。
import androidx.appcompat.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
//加入tensorflow支持
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
String MODEL_FILE = "file:///android_asset/test.pb"; //pb文件的位置
TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(getAssets(),MODEL_FILE); //从载入模型
float[] input1 = new float[3]; //一般采用float数组作为输入、输出
float[] input2 = new float[3]; //具体使用什么类型依网络中实际情况而定
input1[0] = (float) 5.0; input1[1] = (float) 6.0; input1[2] = (float) 1.0;
input2[0] = (float) 2.0; input2[1] = (float) 3.0; input2[2] = (float) 2.0;
float[] output1 = new float[3];
float[] output2 = new float[3];
//喂入输入数据,格式为“输入tensor名称(与模型中设定name的一致)”,“输入数据”,“数据的shape”
inferenceInterface.feed("input1", input1, new long[]{1,3});
inferenceInterface.feed("input2", input2, new long[]{1,3});
//执行run,格式为“输出tensor名称(与模型中设定name的一致)”
inferenceInterface.run(new String[]{"output1","output2"});
//获取输出结果,格式为“输出tensor名称(与模型中设定name的一致)”,“输出数据”
inferenceInterface.fetch("output1", output1);
inferenceInterface.fetch("output2", output2);
//打印结果
for(float f : output1)
Log.e("111111", "output1: " + f);
for(float f : output2)
Log.e("111111", "output2: " + f);
}
}
然后直接安装运行就可以看到如下打印结果:
结果正确。
细心地小伙伴应该发现了一个问题,就是在安卓端每个input都输入了3个数,而模型定义的placeholder的输入shape为2个数。似乎tensorflow并没有对这个输入的大小进行判断,具体原因我也不清楚。不过我这个模型过于简单,计算内容的确不会严重依赖于shape,所以还是劝大家严格按照模型约定的大小进行输入。
另外,我还尝试过将python中的placeholder改成变量,name保持不变。它生成的模型依然可以运行出正确的结果,我估计tensorflow只是强制把输入按name进行对应,也不管它的shape或者类型(placeholder,变量等类型)。
最终总结:
1.在安卓端调用pb模型主要依靠TensorFlowInferenceInterface。
2.模型运行与python上类似,使用run指定输出tensor,就可以运行相应的节点得到结果
3.输入输出tensor的name必须严格对应。
4.python上的 sess.run中的feed_dict被作为一个单独的API,在TensorFlowInferenceInterface里的feed。
5.使用TensorFlowInferenceInterface里的fetch得到推理结果。
其实我也是个tensorflow和android的初学者,如果有什么错误,希望大家能够帮我指出,谢谢大家。