一.模型转换
由于笔者最近需要将一个训练好的keras模型在Android中调用,所以最近开始研究如何在Android中调用模型。python的版本为3.6.10,tensorflow的版本为1.8.0,keras的版本为2.1.6。
笔者拿到的是keras的模型,所以首先需要将该模型转化为.pb格式:
from keras import backend as K
from keras import models
from keras.models import Model
from keras.layers import *
import os
import tensorflow as tf
def keras_to_tensorflow(keras_model,output_dir,model_name, out_prefix = "output_",log_tensorboard = True):
if os.path.exists(output_dir) == False:
os.mkdir(output_dir)
out_nodes = []
for i in range(len(keras_model.outputs)):
out_nodes.append(out_prefix + str(i + 1))
tf.identity(keras_model.output[i], out_prefix + str(i + 1))
sess =K.get_session()
from tensorflow.python.framework import graph_util, graph_io
init_graph = sess.graph.as_graph_def()
main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
graph_io.write_graph(main_graph, output_dir, name = model_name, as_text = False)
if log_tensorboard:
from tensorflow.python.tools import import_pb_to_tensorboard
import_pb_to_tensorboard.import_to_tensorboard( os.path.join(output_dir,model_name),output_dir)
"""
We explicitly redefine the Squeezent architecture since Keras has no predefined Squeezenet
"""
def squeezenet_fire_module(input,input_channel_small = 16,input_channel_large = 64):
channel_axis=3
input=Conv2D(input_channel_small,(1, 1),padding = "valid")(input)
input=Activation("relu")(input)
input_branch_1=Conv2D(input_channel_large,(1, 1),padding = "valid")(input)
input_branch_1=Activation("relu")(input_branch_1)
input_branch_2=Conv2D(input_channel_large,(3,3),padding = "same")(input)
input_branch_2=Activation("relu")(input_branch_2)
input=concatenate([input_branch_1,input_branch_2],axis = channel_axis)
return input
def SqueezeNet(input_shape=(224, 224, 3)):
image_input=Input(shape=input_shape)
network=Conv2D(64,(3, 3),strides = (2, 2),padding = "valid")(image_input)
network=Activation("relu")(network)
network=MaxPool2D(pool_size = (3,3),strides = (2, 2))(network)
network=squeezenet_fire_module(input=network,input_channel_small = 16,input_channel_large = 64)
network=squeezenet_fire_module(input=network,input_channel_small = 16,input_channel_large = 64)
network=MaxPool2D(pool_size=(3, 3),strides = (2, 2))(network)
network=squeezenet_fire_module(input=network,input_channel_small = 32,input_channel_large = 128)
network=squeezenet_fire_module(input=network,input_channel_small = 32,input_channel_large = 128)
network=MaxPool2D(pool_size=(3,3),strides = (2,2))(network)
network=squeezenet_fire_module(input=network,input_channel_small = 48,input_channel_large = 192)
network=squeezenet_fire_module(input=network,input_channel_small = 48,input_channel_large = 192)
network=squeezenet_fire_module(input=network,input_channel_small = 64,input_channel_large = 256)
network=squeezenet_fire_module(input=network,input_channel_small = 64,input_channel_large = 256)
# Remove layers like Dropout and BatchNormalization, they are only needed in training
# network = Dropout(0.5)(network)
network=Conv2D(1000,kernel_size = (1, 1),padding = "valid",name = "last_conv")(network)
network=Activation("relu")(network)
networ=GlobalAvgPool2D()(network)
network=Activation("softmax", name="output")(network)
input_image=image_input
model=Model(inputs=input_image,outputs = network)
return model
keras_model=SqueezeNet()
keras_model = models.load_model('Model_lambda0.05_full') #只需将这里改为自己的模型名称
output_dir=os.path.join(os.getcwd(), "checkpoint")
keras_to_tensorflow(keras_model, output_dir=output_dir, model_name="Model.pb") #输出的模型名称
print("MODEL SAVED")
二.移植到Android
1.在Android studio中新建一个Android项目。
2.把训练好的pb文件(Model.pb)放入Android项目中app/src/main/assets下,若不存在assets目录,则新建一个,右键main->new->Directory,输入assets。
3.将现有的libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar文件放在libs文件夹下。(下载链接为:https://pan.baidu.com/s/18GbL7zvJLMw8DLf7WqkN-Q
提取码:bgoy)

4.pp\build.gradle配置
multiDexEnabled true
ndk {
abiFilters "armeabi-v7a"
}
sourceSets {
main {
jniLibs.srcDirs = ['libs']
}
}
//这里添加libandroid_tensorflow_inference_java.jar包,否则不能解析TensoFlow包
api files('libs/libandroid_tensorflow_inference_java.jar')


三.调用模型
创建PredictionTF.class,该类会先加载libtensorflow_inference.so库文件;PredictionTF(AssetManager assetManager, String modePath) 构造方法需要传入AssetManager对象和pb文件的路径。getPredict利用训练好的TensoFlow模型预测结果,自己定义即可。
public class PredictionTF {
private static final String TAG = "PredictionTF";
//模型中输入变量的名称
private static final String inputName = "input_1_1";
//模型中输出变量的名称
private static final String outputName = "output_1";
TensorFlowInferenceInterface inferenceInterface;
static {
//加载libtensorflow_inference.so库文件
System.loadLibrary("tensorflow_inference");
Log.e(TAG,"libtensorflow_inference.so库加载成功");
}
PredictionTF(AssetManager assetManager, String modePath) {
//初始化TensorFlowInferenceInterface对象
inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);
Log.e(TAG,"TensoFlow模型文件加载成功");
}
/**
* 利用训练好的TensoFlow模型预测结果
*/
public float[] getPredict(float[ ] inputdata) {
//将数据feed给tensorflow的输入节点
inferenceInterface.feed(inputName, inputdata,1,128,3,1); //1,128,3,1代表模型中输入数据格式,这里的输入要和模型里面一样
//运行tensorflow
String[] outputNames = new String[] {outputName};
inferenceInterface.run(outputNames);
///获取输出节点的输出信息
float[] outputs = new float[1*6]; //用于存储模型的输出数据,输出也要和模型里面一样
inferenceInterface.fetch(outputName, outputs);
return outputs;
}
}
创建MainActivity.class,只给出一些主要的代码。
public class MainActivity extends AppCompatActivity {
private static final String TAG = "MainActivity";
private static final String MODEL_FILE = "file:///android_asset/Model.pb"; //模型存放路径
PredictionTF preTF;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
// Example of a call to a native method
preTF =new PredictionTF(getAssets(),MODEL_FILE);//输入模型存放路径,并加载TensoFlow模型
}
//在自己的方法中进行调用
public float[] toatal=new float[100];
public ***
{
float[] result= preTF.getPredict(toatal);
}
}
四.获取模型输入输出
到上一步已经完成了整个模型在Android中调用,但是由于笔者得到的只是一个训练好的模型,并不知道该模型里面输入和输出的变量名称,所以又查阅了很多资料。以下代码主要获取的是.pb模型的各个节点的变量名称。
import tensorflow as tf
import os
model_dir = './'
model_name = 'Model.pb'
def create_graph():
with tf.gfile.FastGFile(os.path.join(
model_dir, model_name), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
create_graph()
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
print(tensor_name, '\n')
本文只是结合了网上各种方法,简单粗暴的理了一下整个过程,如有错误,谢谢指正,互相学习,共同进步!
本文详细介绍如何将训练好的Keras模型转化为.pb格式,并在Android中调用,包括模型转换、移植步骤及调用方法。
953

被折叠的 条评论
为什么被折叠?



