官方给出的demo中运行已经打包好的模型,没有解释怎样从零开始构建自己的模型。参考网站https://omid.al/posts/2017-02-20-Tutorial-Build-Your-First-Tensorflow-Android-App.html,自己做了一些尝试。
准备我们自己的TF模型
首先,我们创建一个简单的模型,把它的计算图保存为一个序列化的GraphDef文件。训练之后,把模型的变量值保存到checkpoint文件中。最后,我们需要把这两个文件变成一个优化了的独立的文件,这个文件是我们在Android App中所需要的所有文件。
创建和保存模型
主要目的是演示过程,所以模型十分简单:一个采用ReLU的单层网络。代码如下:
# Create a simple TF Graph
# By Omid Alemi - Jan 2017
# Works with TF r1.0
import tensorflow as tf
I = tf.placeholder(tf.float32, shape=[None,3], name='I') # input
W = tf.Variable(tf.zeros(shape=[3,2]), dtype=tf.float32, name='W') # weights
b = tf.Variable(tf.zeros(shape=[2]), dtype=tf.float32, name='b') # biases
O = tf.nn.relu(tf.matmul(I, W) + b, name='O') # activation / output
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
# save the graph
tf.train.write_graph(sess.graph_def, '.', 'tfdroid.pbtxt')
# normally you would do some training here
# but fornow we will just assign something to W
sess.run(tf.assign(W, [[1, 2],[4,5],[7,8]]))
sess.run(tf.assign(b, [1,1]))
#save a checkpoint file, which will store the above assignment
saver.save(sess, 'tfdroid.ckpt')
运行上面的代码会把模型的计算图保存在tfdroid.pbtxt文件中,同时把模型变量的checkpoint保存在tfdroid.ckpt中。
冻结图
接下来需要把checkpoint中的变量