Following on from the previous post, this article looks at how to load a deep learning model into an Android app.
Preface
Loading a trained image-recognition model into a mobile app is what lets the model do useful work in practice. This post takes SqueezeNet (a typical lightweight model) as the example and records how to deploy a lightweight model in an Android app.
I. Tools
PyCharm and Android Studio are used as the IDEs for the Python and Android sides of the work, respectively.
II. Steps
1. Converting the model format
Keras is a good library for beginners getting started with deep learning models: its API is simple and easy to understand. Keras normally saves a trained model and its weights in the h5 format, so the model has to be converted to the tflite format before Android can load it.
import tensorflow as tf
from keras_squeezenet import SqueezeNet

# Build SqueezeNet and load the pretrained weights downloaded from the web.
model = SqueezeNet(weights=None)
model.load_weights("./squeezenet_weights_tf_dim_ordering_tf_kernels.h5")

# Convert the Keras model to the TensorFlow Lite format and save it.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("./squeezenet_model.tflite", "wb") as f:
    f.write(tflite_model)
The generated tflite file takes an input image and outputs a probability for each predicted class.
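Before moving to Android it is worth sanity-checking the converted file on the desktop. Below is a minimal sketch using the TensorFlow Lite Python interpreter; the file path and the random test input are assumptions for illustration, and printing the input shape confirms what image size the Android code will need to feed later:

import numpy as np
import tensorflow as tf

# Load the converted model and inspect its input/output details.
interpreter = tf.lite.Interpreter(model_path="./squeezenet_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("input shape:", input_details[0]["shape"])   # confirm the expected image size

# Feed a random image-shaped tensor and read back the class probabilities.
dummy = np.random.rand(*input_details[0]["shape"]).astype(np.float32)
interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()
probs = interpreter.get_tensor(output_details[0]["index"])[0]
print("output shape:", probs.shape, "sum:", probs.sum())   # sum is close to 1.0 if the model ends in a softmax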
Loading the trained model on Android requires not only the tflite model but also a label file.
A label file matching the model can be downloaded from GitHub.
If it is hard to download from there, it can also be obtained here:
Classification label file
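As a quick check that the label file and the model agree, the sketch below continues from the verification snippet above. It assumes Class_Label.txt contains one class name per line, in the same order as the model's output indices; adjust the parsing if the file you downloaded is formatted differently:

# Read the class names, assuming one label per line in output-index order.
with open("Class_Label.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f if line.strip()]

# The number of labels must match the size of the model's output vector.
assert len(labels) == len(probs), "label file does not match the model's output size"

# Print the top-3 predictions for the test input.
top3 = probs.argsort()[-3:][::-1]
for i in top3:
    print(labels[i], float(probs[i]))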
Next, create a project in Android Studio; that process is not described in detail here.
2. Modifying the configuration files
In the new project, add the following to build.gradle:
android {
    ......
    aaptOptions {
        noCompress "tflite"
        noCompress "lite"
    }
}

dependencies {
    .......
    implementation 'org.tensorflow:tensorflow-lite:2.2.0'
}
You can design the UI to suit your own needs; here is the manifest file (AndroidManifest.xml) used in this project:
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="gdut.bsx.tensorflowtraining">

    <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
    <uses-permission android:name="android.permission.CAMERA"/>

    <application
        android:allowBackup="true"
        android:icon="@mipmap/ic_launcher"
        android:label="@string/app_name"
        android:roundIcon="@mipmap/ic_launcher_round"
        android:supportsRtl="true"
        android:theme="@style/AppTheme">

        <activity
            android:name=".MainActivity"
            android:screenOrientation="portrait"
            android:launchMode="singleTask">
            <intent-filter>
                <action android:name="android.intent.action.MAIN"/>
                <category android:name="android.intent.category.LAUNCHER"/>
            </intent-filter>
        </activity>

        <provider
            android:name="android.support.v4.content.FileProvider"
            android:authorities="gdut.bsx.tensorflowtraining.fileprovider"
            android:exported="false"
            android:grantUriPermissions="true">
            <meta-data
                android:name="android.support.FILE_PROVIDER_PATHS"
                android:resource="@xml/file_paths" />
        </provider>
    </application>
</manifest>
Note that the FileProvider entry references res/xml/file_paths.xml, so that resource file also needs to exist in the project.
Place the generated model file (named squeezenet.tflite here so that it matches the MODEL_PATH used in the code below) and Class_Label.txt into the assets folder.
Note: make sure to use the label file that matches the model, otherwise the output indices will map to the wrong class names.
3. Application code
The code implemented here covers taking photos with the camera, managing images, and classifying them.
The files involved include:
ImageClassifier.java
package gdut.bsx.tensorflowtraining;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.os.SystemClock;
import android.util.Log;
import org.tensorflow.lite.Interpreter;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
public class ImageClassifier {
    private static volatile ImageClassifier INSTANCE;

    /** Tag for the {@link Log}. */
    private static final String TAG = "TfLiteCameraDemo";

    /** Name of the model file stored in Assets. */
    private static final String MODEL_PATH = "squeezenet.tflite";
    //private static final String MODEL_PATH = "mobilenet_v1_0.25_128_quant.tflite";

    /** Name of the label file stored in Assets. */
    private static final String LABEL_PATH = "Class_Label.txt";

    /** Number of results to show in the UI. */
    private static final int RESULTS_TO_SHOW = 3;

    /** Dimensions of inputs. */
    private static final int DIM_BATCH_SIZE = 1;
    private static final int DIM_PIXEL_SIZE = 3;
    static final int DIM_IMG_SIZE_X = 224;
    static final int DIM_IMG_SIZE_Y = 224;
    private static final int IMAGE_MEAN = 128;
    private static final float IMAGE_STD = 128.0f;

    /* Preallocated buffer for storing image pixel data. */
    private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];

    /** An instance of the driver class to run model inference with TensorFlow Lite. */
    private Interpreter tflite;

    /** Labels corresponding to the output of the vision model. */
    private List<String> labelList;

    /** A ByteBuffer to hold image data, to be fed into TensorFlow Lite as input. */
    private ByteBuffer imgData;

    /** An array to hold inference results, filled by TensorFlow Lite as output. */
    private float[][] labelProbArray;

    /** Multi-stage low-pass filter used to smooth predictions over time. */
    private float[][] filterLabelProbArray = null;
    private static final int FILTER_STAGES = 3;
    private static final float FILTER_FACTOR = 0.4f;

    /** Priority queue that keeps the top RESULTS_TO_SHOW labels, ordered by confidence. */
    private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
            new PriorityQueue<>(
                    RESULTS_TO_SHOW,
                    new Comparator<Map.Entry<String, Float>>() {
                        @Override
                        public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
                            return (o1.getValue()).compareTo(o2.getValue());
                        }
                    });

    /** Initializes an {@code ImageClassifier}. */
    private ImageClassifier(Context activity) throws IOException {
        tflite = new Interpreter(loadModelFile(activity));
        labelList = loadLabelList(activity);
        imgData = ByteBuffer.allocateDirect(
                4 * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
        imgData.order(ByteOrder.nativeOrder());
        labelProbArray = new float[1][labelList.size(