Tensorflow实现Android移动端的模型搭建

转自http://www.zhimengzhe.com/Androidkaifa/298339.html

Tensorflow实现Android移动端的模型搭建

概述 随着深度学习的广泛应用和Tensoflow的开源,移动端的模型应用层出不穷。本文介绍了笔者在搭建过程中的一些心得,希望可以帮助到你们。 Mac端Tensorflow CPU版本的安装 如果你现在用的没有太好的GPU,可以安装CPU only的Tensorflow。Linux、Mac系统可以安装Tensorflow的python2和python3版本,Windows系统仅支持python3版本。 安装Tensorflow的依赖库bazel,这个后面要用来生成Tensorflow支持Android的jar包和so库。Mac下用brew安装命令:brew install bazel,或者根据bazel官方文档安装相应的版本; 用pip安装tensorflow的CPU only版本:pip install tensorflow; 验证Tensorflow安装是否成功:

import tensorflowas tf

#显示当前Tensorflow版本号

tf.__version__

生成jar包和so库 将github上的tensorflow下载到本地修改tensorflow目录下的WORKSPACE,将其中的sdk和ndk路径改为本地对应路径,其中的sdk的版本号要≥23,ndk的版本号建议是12b(高版本的ndk在用bazel编译时会出现一些问题),build_tools_version版本根据自己实际情况更改: 
 根据下面的指令生成jar包和so库 
参考链接: 
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/android

生成so库的命令:可以选择cpu版本号

bazel build -c opt//tensorflow/contrib/android:libtensorflow_inference.so \

  --crosstool_top=//external:android/crosstool\

 --host_crosstool_top=@bazel_tools//tools/cpp:toolchain\

  --cpu=armeabi-v7a

so库的位置:

bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so

生成jar包的命令:

bazel build//tensorflow/contrib/android:android_tensorflow_inference_java

jar包的位置:

bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar

4. so库也可以选择Tensorflow官方提供的现成的文件,参考链接: 
http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/libtensorflow_inference.so/

Android端的搭建 将jar包放入app->libs目录下,并在build.gradle上添加依赖compilefiles('libs/libandroid_tensorflow_inference_java.jar'); 在src->main目录下新建文件夹jniLibs并将生成的so库放在该目录下; 将PC端训练的模型保存为pb模型

output_graph_def= \

graph_util.convert_variables_to_constants(sess,\

sess.graph_def,output_node_names=['output'])

withtf.gfile.FastGFile("path/to/xxx.pb","wb") as f:

    f.write(output_graph_def.SerializeToString())

 

4.将pb文件放在src->main->assets目录下; 
5. 下面我们用Android代码解释一下Tensorflow在Android端的搭建

public classTensorFlowAudioClassifier implements Classifier{

 

    private static final String TAG ="TensorFlowAudioClassifier";

 

    // Only return this many results with atleast this confidence.

    private static final int MAX_RESULTS = 3;

    private static final float THRESHOLD =0.0f;

 

    // Config values.

    //输入节点的名称(不带后面的':0',只是input的名称,如'input'

    private String inputName;

    //输出节点的名称(通输入节点名称一样)

    private String outputName;

    //输入矩阵的大小(因为一般是方形矩阵,这儿是方形矩阵的size

    private int inputSize;

 

    // Pre-allocated buffers.

    private Vector<String> labels = newVector<String>();

    private float[] floatValues;

    private float[] outputs;

    private String[] outputNames;

 

    private TensorFlowInferenceInterfaceinferenceInterface;

    //这儿使用单例模式

 

    private TensorFlowAudioClassifier() {

    }

 

    /**

     * Initializes a native TensorFlow sessionfor classifying images.

     *

     * @param assetManager  The asset manager to be used to load assets.

     * @param modelFilename The filepath of themodel GraphDef protocol buffer.

     * @param labelFilename The filepath oflabel file for classes.

     * @param inputSize     The input size. A square image ofinputSize x inputSize is assumed.

     * @param inputName     The label of the image input node.

     * @param outputName    The label of the output node.

     * @throws IOException

     */

    public static Classifier create(

            AssetManager assetManager,

            String modelFilename,

            String labelFilename,

            int inputSize,

            String inputName,

            String outputName)

            throws IOException {

        TensorFlowAudioClassifier c = newTensorFlowAudioClassifier();

        c.inputName = inputName;

        c.outputName = outputName;

 

        // Read the label names into memory.

        // TODO(andrewharp): make this handlenon-assets.

        //获取label文件,后面可以用来构建bean

        String actualFilename =labelFilename.split("file:///android_asset/")[1];

        Log.i(TAG, "Reading labels from:" + actualFilename);

        BufferedReader br = null;

        br = new BufferedReader(new InputStreamReader(assetManager.open(actualFilename)));

        String line;

        while ((line = br.readLine()) != null){

            c.labels.add(line);

        }

        br.close();

 

        c.inferenceInterface = newTensorFlowInferenceInterface();

       if (c.inferenceInterface.initializeTensorFlow(assetManager,modelFilename) != 0) {

            throw new RuntimeException("TFinitialization failed");

        }

        // The shape of the output is [N,NUM_CLASSES], where N is the batch size.

        int numClasses =

                (int)c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);

        Log.i(TAG, "Read " +c.labels.size() + " labels, output layer size is " + numClasses);

 

        // Ideally, inputSize could have beenretrieved from the shape of the input operation.  Alas,

        // the placeholder node for input inthe graphdef typically used does not specify a shape, so it

        // must be passed in as a parameter.

        c.inputSize = inputSize;

 

        // Pre-allocate buffers.

        c.outputNames = newString[]{outputName};

        c.floatValues = new float[inputSize *inputSize * 1];

        c.outputs = new float[numClasses];

 

        return c;

    }

 

    @Override

    // 识别过程

    public List<Recognition>recognizeAudio(String fileName) {

        // Log this method so that it can beanalyzed with systrace.

       Trace.beginSection("recognizeAudio");

 

       Trace.beginSection("preprocessAudio");

        // Preprocess the audio data to normalizedfloat based

        // on the provided parameters.

        // 将音频文件构建成输入数组,输入数组是一维float数组,所以在Tensorflow中的特征要转变为一维

        double[][] data =RFFT.inputData(fileName);

        for (int i = 0; i < data.length;++i) {

            for (int j = 0 ; j <data[0].length ; ++j) {

                floatValues[i * 40 + j] =(float)data[i][j];

            }

        }

        Trace.endSection();

 

        // Copy the input data into TensorFlow.

       Trace.beginSection("fillNodeFloat");

        //将输入放入InferenceInterface

        inferenceInterface.fillNodeFloat(

                inputName, newint[]{1,40,40,1}, floatValues);

        Trace.endSection();

 

//     Trace.beginSection("fillNodeFloat");

//     inferenceInterface.fillNodeFloat(inputName2,new int[]{2,2},floatValues);

//      Trace.endSection();

 

        // Run the inference call.

       Trace.beginSection("runInference");

        // 运行模型

       inferenceInterface.runInference(outputNames);

        Trace.endSection();

 

        // Copy the output Tensor back into theoutput array.

       Trace.beginSection("readNodeFloat");

        // 获得输出的confidenceoutputs数组里

       inferenceInterface.readNodeFloat(outputName, outputs);

        Trace.endSection();

 

        // Find the best classifications.

        // PriorityQueue获取top-3,这儿的Recognition来自于接口Classifier,是一个bean

        PriorityQueue<Recognition> pq =

                newPriorityQueue<Recognition>(

                        3,

                        newComparator<Recognition>() {

                            @Override

                            public intcompare(Recognition lhs, Recognition rhs) {

                                //Intentionally reversed to put high confidence at the head of the queue.

                                returnFloat.compare(rhs.getConfidence(), lhs.getConfidence());

                            }

                        });

        for (int i = 0; i < outputs.length;++i) {

            if (outputs[i] > THRESHOLD) {

                pq.add(

        //构建bean类,参数是labellabel nameconfidence

                        new Recognition(

                                "" +i, labels.size() > i ? labels.get(i) : "unknown", outputs[i]));

            }

        }

        final ArrayList<Recognition>recognitions = new ArrayList<Recognition>();

        int recognitionsSize =Math.min(pq.size(), MAX_RESULTS);

        if(recognitionsSize == 0){

            return null;

        }

        for (int i = 0; i <recognitionsSize; ++i) {

            recognitions.add(pq.poll());

        }

        Trace.endSection(); //"recognizeAudio"

        return recognitions;

    }

 

    @Override

    public void enableStatLogging(booleandebug) {

       inferenceInterface.enableStatLogging(debug);

    }

 

    @Override

    public String getStatString() {

        returninferenceInterface.getStatString();

    }

 

    @Override

    public void close() {

        inferenceInterface.close();

    }

}

我们成功地构建了用来识别分类的TensorFlowAudioClassifier,下面展示一下如何使用我们构建的类:

    public static TensorFlowAudioClassifierclassifier;

    private static final String INPUT_NAME ="input";

    private static final String OUTPUT_NAME ="output";

 

    private static final String MODEL_FILE ="file:///android_asset/acoustic.pb";

    private static final String LABEL_FILE =

           "file:///android_asset/eventLabel.txt";

    private static int INPUT_SIZE = 40;

    try {

    // 获取分类器

            classifier =(TensorFlowAudioClassifier) TensorFlowAudioClassifier.create(

                    getAssets(),

                    MODEL_FILE,

                    LABEL_FILE,

                    INPUT_SIZE,

                    INPUT_NAME,

                    OUTPUT_NAME

            );

            // 识别对应Audio的类别

            Recognition result = classifier.recognizeAudio(fileName);

        } catch (IOException e) {

            e.printStackTrace();

        }

 

总结 现在我们已经成功构建了Android端的TensorFlow深度模型。当然我们的思路不是在Android端训练,而只是使用(Android端训练有点异想天开,毕竟训练在PC端都不是很容易实现的,一些层次深的模型还要借助分布式训练)。

如果我的博客对你有帮助,请记得分享给他人。

以上就是Tensorflow实现Android移动端的模型搭建的全文介绍,希望对您学习Android应用开发有所帮助.

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值