将训练好的keras模型在Andorid中调用

一.模型转换

由于笔者最近需要将一个训练好的keras模型在Android中调用,所以最近开始研究如何在Android中调用模型。python的版本为3.6.10,tensorflow的版本为1.8.0,keras的版本为2.1.6。

笔者拿到的是keras的模型,所以首先需要将该模型转化为.pb格式:


from keras import backend as K
from keras import models
from keras.models import Model
from keras.layers import *
import os
import tensorflow as tf



def keras_to_tensorflow(keras_model,output_dir,model_name, out_prefix = "output_",log_tensorboard = True):

    if os.path.exists(output_dir) == False:
         os.mkdir(output_dir)

    out_nodes = []

    for i in  range(len(keras_model.outputs)):
        out_nodes.append(out_prefix + str(i + 1))
        tf.identity(keras_model.output[i], out_prefix + str(i + 1))

    sess =K.get_session()

    from tensorflow.python.framework import graph_util, graph_io




    init_graph = sess.graph.as_graph_def()

    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)

    graph_io.write_graph(main_graph, output_dir, name = model_name, as_text = False)

    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard

        import_pb_to_tensorboard.import_to_tensorboard(  os.path.join(output_dir,model_name),output_dir)


    """
We explicitly redefine the Squeezent architecture since Keras has no predefined Squeezenet
"""


def squeezenet_fire_module(input,input_channel_small = 16,input_channel_large = 64):

    channel_axis=3

    input=Conv2D(input_channel_small,(1, 1),padding = "valid")(input)
    input=Activation("relu")(input)

    input_branch_1=Conv2D(input_channel_large,(1, 1),padding = "valid")(input)
    input_branch_1=Activation("relu")(input_branch_1)

    input_branch_2=Conv2D(input_channel_large,(3,3),padding = "same")(input)
    input_branch_2=Activation("relu")(input_branch_2)

    input=concatenate([input_branch_1,input_branch_2],axis = channel_axis)

    return input


def SqueezeNet(input_shape=(224, 224, 3)):

    image_input=Input(shape=input_shape)


    network=Conv2D(64,(3, 3),strides = (2, 2),padding = "valid")(image_input)
    network=Activation("relu")(network)
    network=MaxPool2D(pool_size = (3,3),strides = (2, 2))(network)

    network=squeezenet_fire_module(input=network,input_channel_small = 16,input_channel_large = 64)
    network=squeezenet_fire_module(input=network,input_channel_small = 16,input_channel_large = 64)
    network=MaxPool2D(pool_size=(3, 3),strides = (2, 2))(network)

    network=squeezenet_fire_module(input=network,input_channel_small = 32,input_channel_large = 128)
    network=squeezenet_fire_module(input=network,input_channel_small = 32,input_channel_large = 128)
    network=MaxPool2D(pool_size=(3,3),strides = (2,2))(network)

    network=squeezenet_fire_module(input=network,input_channel_small = 48,input_channel_large = 192)
    network=squeezenet_fire_module(input=network,input_channel_small = 48,input_channel_large = 192)
    network=squeezenet_fire_module(input=network,input_channel_small = 64,input_channel_large = 256)
    network=squeezenet_fire_module(input=network,input_channel_small = 64,input_channel_large = 256)

     # Remove layers like Dropout and BatchNormalization, they are only needed in training
    # network = Dropout(0.5)(network)

    network=Conv2D(1000,kernel_size = (1, 1),padding = "valid",name = "last_conv")(network)
    network=Activation("relu")(network)

    networ=GlobalAvgPool2D()(network)
    network=Activation("softmax", name="output")(network)


    input_image=image_input
    model=Model(inputs=input_image,outputs = network)

    return model

keras_model=SqueezeNet()

keras_model = models.load_model('Model_lambda0.05_full')  #只需将这里改为自己的模型名称

output_dir=os.path.join(os.getcwd(), "checkpoint")

keras_to_tensorflow(keras_model, output_dir=output_dir, model_name="Model.pb") #输出的模型名称

print("MODEL SAVED")

二.移植到Android

1.在Android studio中新建一个Android项目。

2.把训练好的pb文件(Model.pb)放入Android项目中app/src/main/assets下,若不存在assets目录,则新建一个,右键main->new->Directory,输入assets。

3.将现有的libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar文件放在libs文件夹下。(下载链接为:https://pan.baidu.com/s/18GbL7zvJLMw8DLf7WqkN-Q
提取码:bgoy)
在这里插入图片描述
4.pp\build.gradle配置

       multiDexEnabled true
        ndk {
            abiFilters "armeabi-v7a"
        }
    sourceSets {
        main {
            jniLibs.srcDirs = ['libs']
        }
    }
    //这里添加libandroid_tensorflow_inference_java.jar包,否则不能解析TensoFlow包
    api files('libs/libandroid_tensorflow_inference_java.jar')

在这里插入图片描述
在这里插入图片描述

三.调用模型

创建PredictionTF.class,该类会先加载libtensorflow_inference.so库文件;PredictionTF(AssetManager assetManager, String modePath) 构造方法需要传入AssetManager对象和pb文件的路径。getPredict利用训练好的TensoFlow模型预测结果,自己定义即可。


public class PredictionTF {
    private static final String TAG = "PredictionTF";
    //模型中输入变量的名称
    private static final String inputName = "input_1_1";
    //模型中输出变量的名称
    private static final String outputName = "output_1";

    TensorFlowInferenceInterface inferenceInterface;
    static {
        //加载libtensorflow_inference.so库文件
        System.loadLibrary("tensorflow_inference");
        Log.e(TAG,"libtensorflow_inference.so库加载成功");
    }

    PredictionTF(AssetManager assetManager, String modePath) {
        //初始化TensorFlowInferenceInterface对象
        inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);
        Log.e(TAG,"TensoFlow模型文件加载成功");
    }

    /**
     *  利用训练好的TensoFlow模型预测结果
     */
    public float[] getPredict(float[ ] inputdata) {


        //将数据feed给tensorflow的输入节点
        inferenceInterface.feed(inputName, inputdata,1,128,3,1); //1,128,3,1代表模型中输入数据格式,这里的输入要和模型里面一样
        
        //运行tensorflow
        String[] outputNames = new String[] {outputName};
        inferenceInterface.run(outputNames);
        ///获取输出节点的输出信息
        float[] outputs = new float[1*6]; //用于存储模型的输出数据,输出也要和模型里面一样
        inferenceInterface.fetch(outputName, outputs);

        return outputs;
    }


}

创建MainActivity.class,只给出一些主要的代码。



public class MainActivity extends AppCompatActivity {


    private static final String TAG = "MainActivity";

    private static final String MODEL_FILE = "file:///android_asset/Model.pb"; //模型存放路径
  
  
    PredictionTF preTF;
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        // Example of a call to a native method

      
        preTF =new PredictionTF(getAssets(),MODEL_FILE);//输入模型存放路径,并加载TensoFlow模型

    }

 	//在自己的方法中进行调用
    public float[] toatal=new float[100];
    public  ***
    {

        float[] result= preTF.getPredict(toatal);

    }
}

四.获取模型输入输出

到上一步已经完成了整个模型在Android中调用,但是由于笔者得到的只是一个训练好的模型,并不知道该模型里面输入和输出的变量名称,所以又查阅了很多资料。以下代码主要获取的是.pb模型的各个节点的变量名称。

import tensorflow as tf
import os

model_dir = './'
model_name = 'Model.pb'

def create_graph():
    with tf.gfile.FastGFile(os.path.join(
            model_dir, model_name), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

create_graph()
tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
    print(tensor_name, '\n')

本文只是结合了网上各种方法,简单粗暴的理了一下整个过程,如有错误,谢谢指正,互相学习,共同进步!

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
使用TensorFlow调用Keras模型,需要以下步骤: 1. 加载模型 Keras模型通常是通过HDF5格式保存的,可以使用Tensorflow的`keras.models.load_model`方法来加载模型。例如: ``` import tensorflow as tf from tensorflow import keras # 加载模型 model = keras.models.load_model('/path/to/model.h5') ``` `/path/to/model.h5`是模型所在的文件路径。 2. 运行模型 加载模型后,就可以使用模型进行推断了。例如: ``` import tensorflow as tf from tensorflow import keras import numpy as np # 加载模型 model = keras.models.load_model('/path/to/model.h5') # 输入数据 input_data = np.zeros((1, 224, 224, 3), dtype=np.float32) # 运行模型 output_data = model.predict(input_data) # 输出结果 print(output_data) ``` `input_data`是输入数据,`output_data`是输出结果。`model.predict`方法用于对输入数据进行推断,返回输出结果。 需要注意的是,Keras模型在加载时需要先创建一个TensorFlow的session,可以使用`tf.Session()`方法来创建。例如: ``` import tensorflow as tf from tensorflow import keras import numpy as np # 创建session sess = tf.Session() # 加载模型 model = keras.models.load_model('/path/to/model.h5') # 设置session keras.backend.set_session(sess) # 输入数据 input_data = np.zeros((1, 224, 224, 3), dtype=np.float32) # 运行模型 output_data = model.predict(input_data) # 输出结果 print(output_data) # 关闭session sess.close() ``` 在使用完后需要关闭session,释放资源。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值