TensorFlow的android移植

tensorflow克隆到本地

git clone --recurse-submodules https: //github.com/tensorflow/tensorflow.git

1.安装bazel.
bazeltensorflow工程的主要构建工具
下载bazel-0.5.4-installer-linux-x86_64.sh,执行即可
bazel的简介
2.下载NDK
最好下载 r12b版本的,最新的 r13b可能与 bazel有兼容问题。
3. 编辑Tensorflow根目录下的WORKSPACE文件
在WORKSPACE文件开头部分,修改android_ndk_repository部分
# Uncomment and update the paths in these entries to build the Android demo.
#android_sdk_repository(
#    name = "androidsdk",
#    api_level = 23,
#    # Ensure that you have the build_tools_version below installed in the
#    # SDK manager as it updates periodically.
#    build_tools_version = "25.0.2",
#    # Replace with path to Android SDK on your system
#    path = "/path/to/your/sdk",
#)

# Android NDK r12b is recommended (higher may cause issues with Bazel)
#android_ndk_repository(
#    name="androidndk",
#    path="/path/to/your/ndk",
#    # This needs to be 14 or higher to compile TensorFlow.
#    # Note that the NDK version is not the API level.
#    api_level=14)
这两部分定义了 SDK NDK 的路径。把 /path/to/your 的部分改成相应的路径,然后将每一行前的注释去掉。
因为只需编译TensorFlowLite.so,不需要SDK,只需要NDK

5.编译
生成so库的命令:
bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so \
  --crosstool_top=//external:android/crosstool \
  --host_crosstool_top=@bazel_tools//tools/cpp:toolchain\
  --cpu=armeabi-v7a
生成的so库的位置:
bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so

生成libtensorflowlite_jni.so
bazel build -c opt --cxxopt='--std=c++11'  //tensorflow/contrib/lite/java:libtensorflowlite_jni.so

PC端模型的准备

这是一个很简单的模型,输入是一个数组matrix1,经过操作后,得到这个数组乘以2*matrix1。

  1. 给输入数据命名为input,在android端需要用这个input来为输入数据赋值
  2. 给输出数据命名为output,在android端需要用这个output来为获取输出的值
  3. 不能使用 tf.train.write_graph()保存模型,因为它只是保存了模型的结构,并不保存训练完毕的参数值
  4. 不能使用 tf.train.saver()保存模型,因为它只是保存了网络中的参数值,并不保存模型的结构。
  5. graph_util.convert_variables_to_constants可以把整个sesion当作常量都保存下来,通过output_node_names参数来指定输出
  6. tf.gfile.FastGFile('model/cxq.pb', mode='wb')指定保存文件的路径以及读写方式
  7. f.write(output_graph_def.SerializeToString())将固化的模型写入到文件
# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.python.client import graph_util

session = tf.Session()

matrix1 = tf.constant([[3., 3.]], name='input')
add2Mat = tf.add(matrix1, matrix1, name='output')

session.run(add2Mat)

output_graph_def = graph_util.convert_variables_to_constants(session, session.graph_def,output_node_names=['output'])

with tf.gfile.FastGFile('model/cxq.pb', mode='wb') as f:
    f.write(output_graph_def.SerializeToString())

session.close()

 
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

运行后就会在model文件夹下产生一个cxq.pb文件,现在这个文件将刚才一系列的操作固化了,因此下次需要计算变量乘2时,我们可以直接拿到pb文件,指定输入,再获取输出

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值