将Keras模型迁移到android端(android studio)

环境

  • win8
  • python 3.5.4
  • TensorFlow 1.7.0
  • Android Studio 3.1.1

把模型部署到安卓设备上总体的步骤如下:

  • 将训练好的模型转换成 TensorFlow 格式
  • 向安卓应用添加 TensorFlow Mobile 依赖项
  • 编写相关的 Java 代码,在你的应用中使用 TensorFlow 模型执行推断

一、将训练好的模型转换成 TensorFlow 格式
你可以从这里下载预先训练的Keras Squeezenet model
squeezenet_weights_tf_dim_ordering_tf_kernels.h5

新建python脚本文件(本代码将keras的h5模型转换为tensorflow的pb模型格式,转换逻辑位于keras_to_tensorflow函数中):

from keras.models import Model
from keras.layers import *
import os
import tensorflow as tf


def keras_to_tensorflow(keras_model, output_dir, model_name,out_prefix="output_", log_tensorboard=True):
    """Freeze a Keras model into a binary TensorFlow GraphDef (``.pb``) file.

    Every model output is wrapped in a ``tf.identity`` op named
    ``<out_prefix><n>`` (``n`` starts at 1), the session variables are folded
    into constants, and the resulting graph is written to
    ``output_dir/model_name``.

    Args:
        keras_model: Keras model whose ``outputs`` should be exported.
        output_dir: Directory the graph file is written to; created
            (including missing parents) when it does not exist yet.
        model_name: File name of the exported graph, e.g. ``"model.pb"``.
        out_prefix: Prefix of the generated output node names.
        log_tensorboard: When true, the written graph is additionally
            imported into TensorBoard logs stored in ``output_dir``.
    """
    # Imported explicitly: the original code only found ``K`` because the
    # wildcard ``from keras.layers import *`` happened to leak it.
    from keras import backend as K
    from tensorflow.python.framework import graph_util, graph_io

    os.makedirs(output_dir, exist_ok=True)

    out_nodes = []

    # Iterate ``outputs`` (always a list). ``keras_model.output`` is a single
    # tensor for single-output models, so indexing it would slice the batch
    # axis instead of selecting an output.
    for index, output_tensor in enumerate(keras_model.outputs, start=1):
        node_name = out_prefix + str(index)
        tf.identity(output_tensor, name=node_name)
        out_nodes.append(node_name)

    sess = K.get_session()

    init_graph = sess.graph.as_graph_def()

    # Replace variables by constants so the graph file is self-contained.
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)

    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)

    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard

        import_pb_to_tensorboard.import_to_tensorboard(
            os.path.join(output_dir, model_name),
            output_dir)


"""
We explicitly redefine the SqueezeNet architecture since Keras has no predefined SqueezeNet
"""

def squeezenet_fire_module(input, input_channel_small=16, input_channel_large=64):
    """Stack one SqueezeNet "fire" module on top of ``input``.

    A 1x1 convolution with ``input_channel_small`` filters feeds two parallel
    convolutions (1x1 and 3x3, ``input_channel_large`` filters each); every
    convolution is followed by a ReLU. The two branch results are
    concatenated along axis 3 and returned.
    """
    # Squeeze stage: 1x1 convolution down to the small channel count.
    squeezed = Conv2D(input_channel_small, kernel_size=(1, 1), padding="valid")(input)
    squeezed = Activation("relu")(squeezed)

    # Expand stage, 1x1 branch.
    expand_1x1 = Conv2D(input_channel_large, kernel_size=(1, 1), padding="valid")(squeezed)
    expand_1x1 = Activation("relu")(expand_1x1)

    # Expand stage, 3x3 branch ("same" padding keeps the spatial size equal
    # to the 1x1 branch so both can be concatenated).
    expand_3x3 = Conv2D(input_channel_large, kernel_size=(3, 3), padding="same")(squeezed)
    expand_3x3 = Activation("relu")(expand_3x3)

    return concatenate([expand_1x1, expand_3x3], axis=3)


def SqueezeNet(input_shape=(224,224,3)):
    """Build the SqueezeNet classifier as a Keras ``Model``.

    The network is a strided 3x3 convolution stem, eight fire modules with
    max-pooling after the second and fourth one, a 1x1 convolution to 1000
    channels, global average pooling and a softmax layer named ``"output"``.

    Args:
        input_shape: Shape of one input image (without the batch axis).

    Returns:
        The assembled, uncompiled Keras model.
    """
    image_input = Input(shape=input_shape)

    # Stem: convolution + ReLU + max-pool.
    features = Conv2D(64, kernel_size=(3, 3), strides=(2, 2), padding="valid")(image_input)
    features = Activation("relu")(features)
    features = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(features)

    # (squeeze channels, expand channels, max-pool after this module?)
    fire_plan = (
        (16, 64, False),
        (16, 64, True),
        (32, 128, False),
        (32, 128, True),
        (48, 192, False),
        (48, 192, False),
        (64, 256, False),
        (64, 256, False),
    )
    for squeeze_channels, expand_channels, pool_after in fire_plan:
        features = squeezenet_fire_module(
            input=features,
            input_channel_small=squeeze_channels,
            input_channel_large=expand_channels,
        )
        if pool_after:
            features = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(features)

    # Training-only layers such as Dropout and BatchNormalization are left
    # out on purpose; this model is built for inference.

    # Classification head.
    features = Conv2D(1000, kernel_size=(1, 1), padding="valid", name="last_conv")(features)
    features = Activation("relu")(features)
    features = GlobalAvgPool2D()(features)
    predictions = Activation("softmax", name="output")(features)

    return Model(inputs=image_input, outputs=predictions)


# Build the network and load the pre-trained weights; the .h5 file is looked
# up relative to the current working directory.
keras_model = SqueezeNet()

keras_model.load_weights("squeezenet_weights_tf_dim_ordering_tf_kernels.h5")


# Export into ./checkpoint. The original call ignored this variable and wrote
# to a machine-specific absolute path (E:/PycharmProjects/...), which does
# not exist on other machines.
output_dir = os.path.join(os.getcwd(), "checkpoint")

keras_to_tensorflow(keras_model, output_dir=output_dir, model_name="squeezenet_test.pb")

print("MODEL SAVED")

二、配置AndroidStudio依赖:
请在 Android Studio 中创建一个新的工程。在你的 app:build.gradle 文件中添加 TensorFlow Mobile 依赖

implementation 'org.tensorflow:tensorflow-android:+'

三、android端代码编写:
此时环境已经配好,只需要编写Java代码。这里只使用了最简单的Button、TextView、ImageView控件。
1、在编写代码进行实际推断之前,你需要将转换后的模型(squeezenet_test.pb)添加到应用程序的资源文件夹中。在 Android Studio 中,右键点击你的项目,跳转至「Add Folder」(添加文件夹)部分,并选择「Assets Folder」(资源文件夹)。这将在你的应用程序目录中创建一个资源文件夹。接下来,你需要将模型复制到资源文件夹中。如下:
在这里插入图片描述
其中squeezenet_test.pb为tensorflow的模型文件,labels.json保存了模型输出的类别索引与具体label名称之间的对应关系。you can get labels.json from here
2、将一个新的 Java 类添加到项目的主程序包中,并将其命名为 ImageUtils,把下面的代码复制到其中。

package com.example.doremi.testkeras2tensorflow;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.os.Environment;
import java.io.File;
import java.io.FileOutputStream;
import java.io.InputStream;
import org.json.*;

/**
 * Utility class for manipulating images.
 **/
public class ImageUtils {
    /**
     * Returns a transformation matrix from one reference frame into another.
     * Handles cropping (if maintaining aspect ratio is desired) and rotation.
     *
     * @param srcWidth Width of source frame.
     * @param srcHeight Height of source frame.
     * @param dstWidth Width of destination frame.
     * @param dstHeight Height of destination frame.
     * @param applyRotation Amount of rotation to apply from one frame to another.
     *  Must be a multiple of 90.
     * @param maintainAspectRatio If true, will ensure that scaling in x and y remains constant,
     * cropping the image if necessary.
     * @return The transformation fulfilling the desired requirements.
     */
    public static Matrix getTransformationMatrix(
            final int srcWidth,
            final int srcHeight,
            final int dstWidth,
            final int dstHeight,
            final int applyRotation,
            final boolean maintainAspectRatio) {
        final Matrix matrix = new Matrix();

        if (applyRotation != 0) {
            // Translate so center of image is at origin.
            matrix.postTranslate(-srcWidth / 2.0f, -srcHeight / 2.0f);

            // Rotate around origin.
            matrix.postRotate(applyRotation);
        }

        // Account for the already applied rotation, if any, and then determine how
        // much scaling is needed for each axis.
        final boolean transpose = (Math.abs(applyRotation) + 90) % 180 == 0;

        final int inWidth = transpose ? srcHeight : srcWidth;
        final int inHeight = transpose ? srcWidth : srcHeight;

        // Apply scaling if necessary.
        if (inWidth != dstWidth || inHeight != dstHeight) {
            final float scaleFactorX = dstWidth / (float) inWidth;
            final float scaleFactorY = dstHeight / (float) inHeight;

            if (maintainAspectRatio) {
                // Scale by minimum factor so that dst is filled completely while
                // maintaining the aspect ratio. Some image may fall off the edge.
                final float scaleFactor = Math.max(scaleFactorX, scaleFactorY);
                matrix.postScale(scaleFactor, scaleFactor);
            } else {
                // Scale exactly to fill dst from src.
                matrix.postScale(scaleFactorX, scaleFactorY);
            }
        }

        if (applyRotation != 0) {
            // Translate back from origin centered reference to destination frame.
            matrix.postTranslate(dstWidth / 2.0f, dstHeight / 2.0f);
        }

        return matrix;
    }


    public static Bitmap processBitmap(Bitmap source,int size){

        int image_height = source.getHeight();
        int image_width = source.getWidth();

        Bitmap croppedBitmap = Bitmap.createBitmap(size, size, Bitmap.Config.ARGB_8888);

        Matrix frameToCropTransformations = getTransformationMatrix(image_width,image_height,size,size,0,false);
        Matrix cropToFrameTransformations = new Matrix();
        frameToCropTransformations.invert(cropToFrameTransformations);

        final Canvas canvas = new Canvas(croppedBitmap);
        canvas.drawBitmap(source, frameToCropTransformations, null);

        return croppedBitmap;


    }

    public static float[] normalizeBitmap(Bitmap source,int size,float mean,float std){

        float[] output = new float[size * size * 3];

        int[] intValues = new int[source.getHeight() * source.getWidth()];

        source.getPixels(intValues, 0, source.getWidth(), 0, 0, source.getWidth(), source.getHeight());
        for (int i = 0; i < intValues.length; ++i) {
            final int val = intValues[i];
            output[i * 3] = (((val >> 16) & 0xFF) - mean)/std;
            output[i * 3 + 1] = (((val >> 8) & 0xFF) - mean)/std;
            output[i * 3 + 2] = ((val & 0xFF) - mean)/std;
        }

        return output;

    }

    public static Object[] argmax(float[] array){


        int best = -1;
        float best_confidence = 0.0f;

        for(int i = 0;i < array.length;i++){

            float value = array[i];

            if (value > best_confidence){

                best_confidence = value;
                best = i;
            }
        }



        return new Object[]{best,best_confidence};


    }


    public static String getLabel( InputStream jsonStream,int index){
        String label = "";
        try {

            byte[] jsonData = new byte[jsonStream.available()];
            jsonStream.read(jsonData);
            jsonStream.close();

            String jsonString = new String(jsonData,"utf-8");

            JSONObject object = new JSONObject(jsonString);

            label = object.getString(String.valueOf(index));



        }
        catch (Exception e){


        }
        return label;
    }
}


假如只是用来开发的话对于ImageUtils这个类不需要理解代码实现,会用就好啦。
3、在你的主活动(main activity)中添加如下代码。它们将被用于显示图像和预测结果。

package com.example.doremi.testkeras2tensorflow;

import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;


import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.AsyncTask;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

import java.io.InputStream;

/**
 * Demo activity: on a button press it loads {@code testimage.png} from the
 * assets, shows it, runs the bundled TensorFlow model on it in a background
 * task and displays the best label together with its confidence.
 */
public class MainActivity extends AppCompatActivity {
    /*
     * Wherever TensorFlow is needed, load the native library with
     * System.loadLibrary("tensorflow_inference") and import
     * org.tensorflow.contrib.android.TensorFlowInferenceInterface.
     */
    // Load the tensorflow inference library.
    // A static initializer block runs exactly once, when the class is loaded.
    static {
        System.loadLibrary("tensorflow_inference");
        Log.i("wumei","load tensorflow_inference successfully");
    }

    // Path to the model file and names of the graph's input and output nodes.
    private String MODEL_PATH = "file:///android_asset/squeezenet_test.pb";
    private String INPUT_NAME = "input_1";
    private String OUTPUT_NAME = "output_1";
    private TensorFlowInferenceInterface tf;

    // Buffer receiving the model predictions, the normalized image data fed
    // to the model, and the expected input dimensions.
    float[] PREDICTIONS = new float[1000];
    private float[] floatValues;
    private int[] INPUT_SIZE = {224,224,3};

    ImageView imageView;
    TextView resultView;
    Button buttonSub;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        tf = new TensorFlowInferenceInterface(getAssets(),MODEL_PATH);

        imageView=(ImageView)findViewById(R.id.imageView1);
        resultView=(TextView)findViewById(R.id.text_show);
        buttonSub=(Button)findViewById(R.id.button1);

        buttonSub.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                InputStream imageStream = null;
                try{
                    // Read the image from the assets folder.
                    imageStream = getAssets().open("testimage.png");
                    Log.d("wumei",imageStream.toString());
                    Bitmap bitmap = BitmapFactory.decodeStream(imageStream);
                    if (bitmap == null) {
                        // decodeStream returns null for undecodable data;
                        // passing that on would crash the background task.
                        Log.e("wumei", "testimage.png could not be decoded");
                        return;
                    }
                    imageView.setImageBitmap(bitmap);

                    predict(bitmap);

                }catch(Exception e){
                    Log.e("wumei", "Failed to load or classify testimage.png", e);
                }finally{
                    if (imageStream != null) {
                        try {
                            imageStream.close();
                        } catch (java.io.IOException ignored) {
                            // Nothing sensible can be done if closing fails.
                        }
                    }
                }
            }
        });

    }

    /**
     * Computes the maximum prediction and its confidence.
     *
     * @param array Values to search.
     * @return A two-element array: the index of the maximum as {@code Integer}
     *     and the maximum itself as {@code Float}; {@code {-1, 0.0f}} when no
     *     element could be selected (empty array).
     */
    public Object[] argmax(float[] array){

        int best = -1;
        // Start below every real value; seeding with 0 would make arrays whose
        // entries are all <= 0 report "nothing found".
        float bestConfidence = Float.NEGATIVE_INFINITY;
        for (int i = 0; i < array.length; i++) {
            final float value = array[i];
            if (value > bestConfidence) {
                bestConfidence = value;
                best = i;
            }
        }
        // Keep the historical {-1, 0.0f} result when nothing was selected.
        final float reported = best >= 0 ? bestConfidence : 0.0f;
        return new Object[]{best, reported};
    }

    /**
     * Classifies {@code bitmap} on a background thread and shows
     * {@code "<label> : <confidence>%"} in the result view.
     *
     * @param bitmap Image to classify; it is scaled to 224 x 224 before inference.
     */
    public void predict(final Bitmap bitmap){

        // Runs inference in a background thread.
        new AsyncTask<Integer,Integer,Integer>(){

            @Override
            protected Integer doInBackground(Integer ...params){
                // Resize the image to 224 x 224.
                Bitmap resized_image = ImageUtils.processBitmap(bitmap,224);

                // Normalize the pixels.
                floatValues = ImageUtils.normalizeBitmap(resized_image,224,127.5f,1.0f);

                // Pass the input to tensorflow.
                tf.feed(INPUT_NAME,floatValues,1,224,224,3);

                // Compute the predictions.
                tf.run(new String[]{OUTPUT_NAME});

                // Copy the output into the PREDICTIONS array.
                tf.fetch(OUTPUT_NAME,PREDICTIONS);

                // Obtain the highest prediction.
                Object[] results = argmax(PREDICTIONS);

                int class_index = (Integer) results[0];
                float confidence = (Float) results[1];

                try{
                    // Fixed two-decimal formatting. The previous
                    // substring(0,5) threw for short strings such as "50.0",
                    // which silently suppressed the whole result.
                    final String conf = String.format(java.util.Locale.US, "%.2f", confidence * 100);
                    // Convert the predicted class index into the label name.
                    final String label = ImageUtils.getLabel(getAssets().open("labels.json"),class_index);
                    // Display the result on the UI thread.
                    runOnUiThread(new Runnable() {
                        @Override
                        public void run() {
                            // Show the current result in the text view.
                            resultView.setText(label + " : " + conf + "%");
                        }
                    });
                } catch (Exception e){
                    Log.e("wumei", "Failed to display the prediction", e);
                }

                return 0;
            }

        }.execute(0);

    }
}


其中模型的推理部分放入到了predict函数中,并且将其耗时操作放到了子线程中执行。
4.需要用到的布局文件(虽然对于好多人来说,这一步是多余的)

<?xml version="1.0" encoding="utf-8"?>
<!-- Layout for the demo screen (presumably res/layout/activity_main.xml, since
     MainActivity inflates R.layout.activity_main): a vertical stack holding the
     trigger button, the result text and the image view. -->
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:orientation="vertical"
    android:layout_width="match_parent"
    android:layout_height="wrap_content">

    <!-- Looked up as R.id.button1; its click listener starts the prediction. -->
    <Button
        android:id="@+id/button1"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:text="predict" />

    <!-- Looked up as R.id.text_show; receives "<label> : <confidence>%". -->
    <TextView
        android:id="@+id/text_show"
        android:layout_width="wrap_content"
        android:layout_height="52dp"
        android:layout_gravity="start"
        android:text=" you can get prediction result from here" />

    <!-- Looked up as R.id.imageView1; the decoded asset bitmap is set on it.
         NOTE(review): the background references @drawable/testimage, which
         must exist as a drawable resource in addition to the asset - confirm. -->
    <ImageView
        android:id="@+id/imageView1"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:background="@drawable/testimage"
        android:layout_gravity="center_horizontal"/>

</LinearLayout>

布局预览:
preview
5.you can click the run!!!!
下面是最后的效果图
result

reference
1
2
3
4

附:
windows查看python版本号:python --version
windows查看TensorFlow版本:在python中依次执行 import tensorflow as tf 和 print(tf.__version__)
查询tensorflow安装路径为:tf.__path__

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值