如何将深度学习模型加载到android环境中

承接上一篇的内容,考虑如何将深度学习的模型加载到android app中


前言

将图片学习的模型加载到手机应用中,让学习模型发挥实际用处。这里以SqueezeNet模型(典型轻量化模型)为参考,记录如何将轻量化模型应用到手机app上。

一、使用工具

采用pycharm和android studio分别作为python和android开发的IDE。

二、使用步骤

1.模型格式的转换

Keras可以作为初学者入手深度学习模型不错的模型库,其语言简单且易于理解。Keras一般将训练好的模型和参数以h5格式保存,这里需要转换成tflite格式,以供android调用。

from keras_squeezenet import SqueezeNet
model = SqueezeNet(weights=None)
model.load_weights("./squeezenet_weights_tf_dim_ordering_tf_kernels.h5") ###从网上下载SequeezeNet网络模型及参数

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("./squeezenet_model.tflite", "wb").write(tflite_model)

这里生成的tflite文件能够根据输入图片产生预测结果的概率值。

将训练的模型加载到android上不仅需要tflite模型,还需要label文件。

在github中下载的模型相配套的label文件。

时候不好下载,可以从此处下载:
分类标签文件

接下来通过Android Studio建立项目,该过程这里不详细描述。

2.配置文件修改

在新建的项目中,在build.gradle中增加

android {
   
......

aaptOptions {
   
    noCompress "tflite"
    noCompress "lite"
}
}

dependencies {
   
.......
implementation 'org.tensorflow:tensorflow-lite:2.2.0'
}

大家可以根据自己的要求写界面文件,这里附上写的界面文件(AndroidManifest.xml):

<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
          package="gdut.bsx.tensorflowtraining">

    <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
    <uses-permission android:name="android.permission.CAMERA"/>

    <application
            android:allowBackup="true"
            android:icon="@mipmap/ic_launcher"
            android:label="@string/app_name"
            android:roundIcon="@mipmap/ic_launcher_round"
            android:supportsRtl="true"
            android:screenOrientation="portrait"
            android:theme="@style/AppTheme">

        <activity android:name=".MainActivity" android:screenOrientation="portrait"
                  android:launchMode="singleTask">
            <intent-filter>
                <action android:name="android.intent.action.MAIN"/>

                <category android:name="android.intent.category.LAUNCHER"/>
            </intent-filter>
        </activity>

        <provider
                android:name="android.support.v4.content.FileProvider"
                android:authorities="gdut.bsx.tensorflowtraining.fileprovider"
                android:exported="false"
                android:grantUriPermissions="true">
            <meta-data
                    android:name="android.support.FILE_PROVIDER_PATHS"
                    android:resource="@xml/file_paths" />
        </provider>

    </application>

</manifest>

将生成的模型文件 sequeezenet.tflite和Class_Label.txt 放入文件夹assets中。
注意:一定要用分类标签文件。

3. 应用程序

这里我们实现的代码主要功能包括摄像头拍照,图片管理、图像分类。
包含代码包括:
ImageClassifier.java

package gdut.bsx.tensorflowtraining;

import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.os.SystemClock;
import android.util.Log;

import org.tensorflow.lite.Interpreter;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;

public class ImageClassifier {
   

    private static volatile ImageClassifier INSTANCE;

    /**
     * Tag for the {@link Log}.
     */

    private static final String TAG = "TfLiteCameraDemo";
    /**
     * Name of the model file stored in Assets.
     */

    private static final String MODEL_PATH = "squeezenet.tflite";
    //private static final String MODEL_PATH = "mobilenet_v1_0.25_128_quant.tflite";

    /**
     * Name of the label file stored in Assets.
     */

    private static final String LABEL_PATH = "Class_Label.txt";

    /** Number of results to show in the UI. */
    private static final int RESULTS_TO_SHOW = 3;
    /** Dimensions of inputs. */
    private static final int DIM_BATCH_SIZE = 1;
    private static final int DIM_PIXEL_SIZE = 3;

    static final int DIM_IMG_SIZE_X = 224;
    static final int DIM_IMG_SIZE_Y = 224;

    private static final int IMAGE_MEAN = 128;
    private static final float IMAGE_STD = 128.0f;

    /* Preallocated buffers for storing image data in. */
    private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];

    /**
     * An instance of the driver class to run model inference with Tensorflow Lite.
     */

    private Interpreter tflite;

    /**
     * Labels corresponding to the output of the vision model.
     */

    private List<String> labelList;
    /**
     * A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs.
     */

    private ByteBuffer imgData ;
    /**
     * An array to hold inference results, to be feed into Tensorflow Lite as outputs.
     */
    private float[][] labelProbArray ;


    /**
     * multi-stage low pass filter
     **/

    private float[][] filterLabelProbArray = null;
    private static final int FILTER_STAGES = 3;
    private static final float FILTER_FACTOR = 0.4f;


    /**
     * Initializes an {@code ImageClassifier}.
     */


    private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
            new PriorityQueue<>(
                    RESULTS_TO_SHOW,
                    new Comparator<Map.Entry<String, Float>>() {
   
                        @Override
                        public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
   
                            return (o1.getValue()).compareTo(o2.getValue());
                        }
                    });


    private ImageClassifier(Context activity) throws IOException {
   
        tflite = new Interpreter(loadModelFile(activity));
        labelList = loadLabelList(activity);
        imgData = ByteBuffer.allocateDirect(
                        4 * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
        imgData.order(ByteOrder.nativeOrder());
        labelProbArray = new float[1][labelList.size(
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值