CV深度学习模型Android端落地方案之三:使用Tensorflow Lite 将自己的训练得到的模型移植到Android上

这个系列的博客主要介绍如何在Android设备上移植你训练的cv神经网络模型。

主要过程如下:
1、使用Android Camera2 APIs获得摄像头实时预览的画面。
2、如果是对人脸图像进行处理,使用Android Camera2自带的Face类来对人脸检测,并完成在预览画面上画框将人脸框出、添加文字显示神经网络处理结果的功能。
3、使用Tensorflow Lite 将自己的训练得到的模型移植到Android上。

以上三个步骤会分为三个博客,同时也会提供示例代码。步骤二可以根据你的实际需求跳过或修改。这是这个系列的第三篇博客。

转载连接

准备

  • 一个你训练好的以pb结尾的Tensorflow模型,如果你的模型是caffemodel可以使用代码将其转换成pb模型。
  • 一个能调用你的模型完成你想要功能的Python脚本,以确保你的模型可以使用,以手写字体模型Minist为例:
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
 
#模型路径
model_path = 'model/mnist.pb'
#测试图片
testImage = Image.open("data/test_image.jpg");
 
with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    with open(model_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")
 
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # x_test = x_test.reshape(1, 28 * 28)
        input_x = sess.graph.get_tensor_by_name("input/x_input:0")
        output = sess.graph.get_tensor_by_name("output:0")
 
        #对图片进行测试
        testImage=testImage.convert('L')
        testImage = testImage.resize((28, 28))
        test_input=np.array(testImage)
        test_input = test_input.reshape(1, 28 * 28)
        pre_num = sess.run(output, feed_dict={input_x: test_input})#利用训练好的模型预测结果
        print('模型预测结果为:',pre_num)
        #显示测试的图片
        # testImage = test_x.reshape(28, 28)
        fig = plt.figure(), plt.imshow(testImage,cmap='binary')  # 显示图片
        plt.title("prediction result:"+str(pre_num))
        plt.show()

  • 清楚你的输入lable(inputName)和输出lable(outputName),如果不清楚,可以使用如下代码输出pb模型的层级结构
# coding:utf-8
import tensorflow as tf

#输出保存的模型中参数名字及对应的值
with tf.gfile.GFile('ResNet_model.pb', "rb") as f:  #读取模型数据
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) #得到模型中的计算图和数据
with tf.Graph().as_default() as graph:  # 这里的Graph()要有括号,不然会报TypeError
    tf.import_graph_def(graph_def, name="")  #导入模型中的图到现在这个新的计算图中,不指定名字的话默认是 import
    for op in graph.get_operations():  # 打印出图中的节点信息
        print(op.name, op.values())

  • 清楚你的输入向量和输出向量的关系
输入输出IN_COLIN_ROWOUT_COLOUT_ROW代码
输入:单通道28*28输出:1*1128*2811inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);
输入:三通道224*224输出:1*73224*22417inferenceInterface.feed(inputName, inputdata, (1,3,224,224));

Android Studio配置

(1)新建一个Android项目。
(2)把训练好的pb文件(mnist.pb)放入Android项目中app/src/main/assets下,若不存在assets目录,右键main->new->Directory,输入assets。
(3)将下载的libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar如下结构放在libs文件夹下。

在这里插入图片描述

(4)app\build.gradle配置

在defaultConfig中添加

multiDexEnabled true
    ndk {
        abiFilters "armeabi-v7a"
    }

增加sourceSets

sourceSets {
    main {
        jniLibs.srcDirs = ['libs']
    }
}

在这里插入图片描述

在dependencies中增加TensoFlow编译的jar文件libandroid_tensorflow_inference_java.jar:

compile files('libs/libandroid_tensorflow_inference_java.jar')

在这里插入图片描述

代码

在需要调用TensoFlow的地方,加载so库“System.loadLibrary(“tensorflow_inference”);并”import org.tensorflow.contrib.android.TensorFlowInferenceInterface;就可以使用了

注意,旧版的TensoFlow,是如下方式进行,该方法可参考大神的博客:https://www.jianshu.com/p/1168384edc1e

TensorFlowInferenceInterface.fillNodeFloat(); //送入输入数据
TensorFlowInferenceInterface.runInference();  //进行模型的推理
TensorFlowInferenceInterface.readNodeFloat(); //获取输出数据

但在最新的libandroid_tensorflow_inference_java.jar中,已经没有这些方法了,换为

TensorFlowInferenceInterface.feed()
TensorFlowInferenceInterface.run()
TensorFlowInferenceInterface.fetch()

下面是以MNIST手写数字识别为例,其实现方法如下:

package com.example.jinquan.pan.mnist_ensorflow_androiddemo;
 
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Color;
import android.graphics.Matrix;
import android.util.Log;
 
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
 
 
public class PredictionTF {
    private static final String TAG = "PredictionTF";
    //设置模型输入/输出节点的数据维度
    private static final int IN_COL = 1;
    private static final int IN_ROW = 28*28;
    private static final int OUT_COL = 1;
    private static final int OUT_ROW = 1;
    //模型中输入变量的名称
    private static final String inputName = "input/x_input";
    //模型中输出变量的名称
    private static final String outputName = "output";
 
    TensorFlowInferenceInterface inferenceInterface;
    static {
        //加载libtensorflow_inference.so库文件
        System.loadLibrary("tensorflow_inference");
        Log.e(TAG,"libtensorflow_inference.so库加载成功");
    }
 
    PredictionTF(AssetManager assetManager, String modePath) {
        //初始化TensorFlowInferenceInterface对象
        inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);
        Log.e(TAG,"TensoFlow模型文件加载成功");
    }
 
    /**
     *  利用训练好的TensoFlow模型预测结果
     * @param bitmap 输入被测试的bitmap图
     * @return 返回预测结果,int数组
     */
    public int[] getPredict(Bitmap bitmap) {
        float[] inputdata = bitmapToFloatArray(bitmap,28, 28);//需要将图片缩放带28*28
        //将数据feed给tensorflow的输入节点
        inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);
        //运行tensorflow
        String[] outputNames = new String[] {outputName};
        inferenceInterface.run(outputNames);
        ///获取输出节点的输出信息
        int[] outputs = new int[OUT_COL*OUT_ROW]; //用于存储模型的输出数据
        inferenceInterface.fetch(outputName, outputs);
        return outputs;
    }
 
    /**
     * 将bitmap转为(按行优先)一个float数组,并且每个像素点都归一化到0~1之间。
     * @param bitmap 输入被测试的bitmap图片
     * @param rx 将图片缩放到指定的大小(列)->28
     * @param ry 将图片缩放到指定的大小(行)->28
     * @return   返回归一化后的一维float数组 ->28*28
     */
    public static float[] bitmapToFloatArray(Bitmap bitmap, int rx, int ry){
        int height = bitmap.getHeight();
        int width = bitmap.getWidth();
        // 计算缩放比例
        float scaleWidth = ((float) rx) / width;
        float scaleHeight = ((float) ry) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
        Log.i(TAG,"bitmap width:"+bitmap.getWidth()+",height:"+bitmap.getHeight());
        Log.i(TAG,"bitmap.getConfig():"+bitmap.getConfig());
        height = bitmap.getHeight();
        width = bitmap.getWidth();
        float[] result = new float[height*width];
        int k = 0;
        //行优先
        for(int j = 0;j < height;j++){
            for (int i = 0;i < width;i++){
                int argb = bitmap.getPixel(i,j);
                int r = Color.red(argb);
                int g = Color.green(argb);
                int b = Color.blue(argb);
                int a = Color.alpha(argb);
                //由于是灰度图,所以r,g,b分量是相等的。
                assert(r==g && g==b);
//                Log.i(TAG,i+","+j+" : argb = "+argb+", a="+a+", r="+r+", g="+g+", b="+b);
                result[k++] = r / 255.0f;
            }
        }
        return result;
    }
}
  • 简单说明一下:项目新建了一个PredictionTF类,该类会先加载libtensorflow_inference.so库文件;PredictionTF(AssetManager assetManager, String modePath) 构造方法需要传入AssetManager对象和pb文件的路径;
  • 从资源文件中获取BitMap图片,并传入 getPredict(Bitmap bitmap)方法,该方法首先将BitMap图像缩放到2828的大小,由于原图是灰度图,我们需要获取灰度图的像素值,并将2828的像素转存为行向量的一个float数组,并且每个像素点都归一化到0~1之间,这个就是bitmapToFloatArray(Bitmap bitmap, int rx, int ry)方法的作用;
  • 然后将数据feed给tensorflow的输入节点,并运行(run)tensorflow,最后获取(fetch)输出节点的输出信息。

MainActivity很简单,一个单击事件获取预测结果:

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.ImageView;
import android.widget.TextView;

public class MainActivity extends AppCompatActivity {

    // Used to load the 'native-lib' library on application startup.
    static {
        System.loadLibrary("native-lib");//可以去掉
    }

    private static final String TAG = "MainActivity";
    private static final String MODEL_FILE = "file:///android_asset/mnist.pb"; //模型存放路径
    TextView txt;
    TextView tv;
    ImageView imageView;
    Bitmap bitmap;
    PredictionTF preTF;
    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        // Example of a call to a native method
        tv = (TextView) findViewById(R.id.sample_text);
        txt=(TextView)findViewById(R.id.txt_id);
        imageView =(ImageView)findViewById(R.id.imageView1);
        bitmap = BitmapFactory.decodeStream(getClass().getResourceAsStream("/res/drawable/test.bmp"));
        imageView.setImageBitmap(bitmap);
        preTF = new PredictionTF(getAssets(),MODEL_FILE);//输入模型存放路径,并加载TensoFlow模型
    }

    public void click01(View v){
        String res="预测结果为:";
        int[] result= preTF.getPredict(bitmap);
        for (int i=0;i<result.length;i++){
            Log.i(TAG, res+result[i] );
            res=res+String.valueOf(result[i])+" ";
        }
        txt.setText(res);
    }
    /**
     * A native method that is implemented by the 'native-lib' native library,
     * which is packaged with this application.
     */
}

activity_main布局文件:

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    android:paddingBottom="16dp"
    android:paddingLeft="16dp"
    android:paddingRight="16dp"
    android:paddingTop="16dp">
    <TextView
        android:id="@+id/sample_text"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="https://blog.csdn.net/guyuealian"
        android:layout_gravity="center"/>
    <Button
        android:onClick="click01"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:text="click" />
    <TextView
        android:id="@+id/txt_id"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:gravity="center"
        android:text="结果为:"/>
    <ImageView
        android:id="@+id/imageView1"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_gravity="center"/>
</LinearLayout>

实际效果
在这里插入图片描述

注意事项

  • 不同的神经网络对于输入图片的要求也不一样,有些需要转化成灰度图,有些需要归一化,有些要对尺寸进行裁剪,有些要减去均值。最好是先写一个Python脚本,将你的模型跑起来,完成需要的预处理,检测神经网络的输出是否正确。通过这个脚本了解输入输出数据的维度,再根据这个Python脚本去改写出对应的java代码。

  • 特别要注意的是:

inferenceInterface.feed(inputName, inputdata, (图片数量,图片长,图片宽,图片通道数)); // inputdata是一个(图片数量*图片长*图片宽*图片通道数)维的float32数组

写好inferenceInterface.feed,才能成功起调模型,得到你想要的结果。

本博客转载于这位博主的博客

转载链接

演示Demo

TensorflowLiteDemo

评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值