android工程做成框架,TensorFlow集成Android工程的框架

欢迎Follow我的GitHub,关注我的简书

在Android工程中,集成TensorFlow模型。运行TensorFlow的默认Android工程,请参考。

库及模型的大小

libtensorflow_inference.so 10.2 M

libandroid_tensorflow_inference_java.jar 27 KB

optimized_tfdroid.pb 291 B

如果将so转换为jar库,参考,则TF的so由10.2M缩小至4.1M。

870e9a54749a

TF Android

TensorFlow

创建TensorFlow模型,简单的y=WX+b,存储图信息write_graph,存储参数信息saver.save。输入数据placeholder是I,输出数据是O。

import tensorflow as tf

I = tf.placeholder(tf.float32, shape=[None, 3], name='I') # input

W = tf.Variable(tf.zeros(shape=[3, 2]), dtype=tf.float32, name='W') # weights

b = tf.Variable(tf.zeros(shape=[2]), dtype=tf.float32, name='b') # biases

O = tf.nn.relu(tf.matmul(I, W) + b, name='O') # activation / output

saver = tf.train.Saver()

init_op = tf.global_variables_initializer()

with tf.Session() as sess:

sess.run(init_op)

tf.train.write_graph(sess.graph_def, './data/android/', 'tfdroid.pbtxt') # 存储TensorFlow的图

# 训练数据,本例直接赋值

sess.run(tf.assign(W, [[1, 2], [4, 5], [7, 8]]))

sess.run(tf.assign(b, [1, 1]))

# 存储checkpoint文件,即参数信息

saver.save(sess, './data/android/tfdroid.ckpt')

创建Freeze的图,将图结构与参数组合在一起,生成模型,参考。

def gnr_freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,

output_node_names, output_graph, clear_devices):

"""

将输入图与参数结合在一起

:param input_graph: 输入图

:param input_saver: Saver解析器

:param input_binary: 输入图的格式,false是文本,true是二进制

:param input_checkpoint: checkpoint,检查点文件

:param output_node_names: 输出节点名称

:param output_graph: 保存输出文件

:param clear_devices: 清除训练设备

:return: NULL

"""

restore_op_name = "save/restore_all"

filename_tensor_name = "save/Const:0"

freeze_graph.freeze_graph(

input_graph=input_graph, # 输入图

input_saver=input_saver, # Saver解析器

input_binary=input_binary, # 输入图的格式,false是文本,true是二进制

input_checkpoint=input_checkpoint, # checkpoint,检查点文件

output_node_names=output_node_names, # 输出节点名称

restore_op_name=restore_op_name, # 从模型恢复节点的名字

filename_tensor_name=filename_tensor_name, # tensor名称

output_graph=output_graph, # 保存输出文件

clear_devices=clear_devices, # 清除训练设备

initializer_nodes="") # 初始化节点

优化模型,剪切节点,模型只保留输入输出的参数。

def gnr_optimize_graph(graph_path, optimized_graph_path):

"""

优化图

:param graph_path: 原始图

:param optimized_graph_path: 优化的图

:return: NULL

"""

input_graph_def = tf.GraphDef() # 读取原始图

with tf.gfile.Open(graph_path, "r") as f:

data = f.read()

input_graph_def.ParseFromString(data)

# 设置输入输出节点,剪切分支,大约节省1/4

output_graph_def = optimize_for_inference_lib.optimize_for_inference(

input_graph_def,

["I"], # an array of the input node(s)

["O"], # an array of output nodes

tf.float32.as_datatype_enum)

# 存储优化的图

f = tf.gfile.FastGFile(optimized_graph_path, "w")

f.write(output_graph_def.SerializeToString())

执行函数,生成模型,frozen_tfdroid.pb和optimized_tfdroid.pb。

if __name__ == "__main__":

input_graph_path = MODEL_FOLDER + MODEL_NAME + '.pbtxt' # 输入图

checkpoint_path = MODEL_FOLDER + MODEL_NAME + '.ckpt' # 输入参数

output_path = MODEL_FOLDER + 'frozen_' + MODEL_NAME + '.pb' # Freeze模型

gnr_freeze_graph(input_graph=input_graph_path, input_saver="",

input_binary=False, input_checkpoint=checkpoint_path,

output_node_names="O", output_graph=output_path, clear_devices=True)

optimized_output_graph = MODEL_FOLDER + 'optimized_' + MODEL_NAME + '.pb'

gnr_optimize_graph(output_path, optimized_output_graph)

Android

编译Android的库,参考,或者,直接在Nightly中下载,参考,archive.zip,大约158M。

创建Android工程,添加app/libs/中添加库文件。

armeabi-v7a/libtensorflow_inference.so

libandroid_tensorflow_inference_java.jar

在build.gradle中,添加

android {

sourceSets {

main {

jniLibs.srcDirs = ['libs']

}

}

}

在app/src/main/assets中,添加模型optimized_tfdroid.pb文件。

在MainActivity中,添加so库。

static {

System.loadLibrary("tensorflow_inference");

}

模型文件在assets中,TF的核心接口类TensorFlowInferenceInterface。

private static final String MODEL_FILE = "file:///android_asset/optimized_tfdroid.pb";

private TensorFlowInferenceInterface mInferenceInterface;

初始模型文件

mInferenceInterface = new TensorFlowInferenceInterface();

mInferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);

模型Feed数据,输入点名称是INPUT_NODE,输入结构INPUT_SIZE,输入数据inputFloats。

float[] inputFloats = {num1, num2, num3};

mInferenceInterface.fillNodeFloat(INPUT_NODE, INPUT_SIZE, inputFloats);

模型执行文件,输出点名称是OUTPUT_NODE,即"O"

mInferenceInterface.runInference(new String[]{OUTPUT_NODE});

输出数据结构

float[] resu = {0, 0};

mInferenceInterface.readNodeFloat(OUTPUT_NODE, resu);

最后,在layout中创建GUI布局。

效果

870e9a54749a

Demo

TensorFlow集成至春雨医生

870e9a54749a

CY-TF

That's all! Enjoy it!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值