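Introduction:
In 2019 deep learning keeps getting more popular, and neural networks are showing up in every industry. I want to join the ranks of AI engineers too. Right now I need to build an Android app that relies on data classification, and since every major vendor is racing to ship its own AI inference framework, I feel duty-bound to openly grab a bowl of rice for myself. I don't think this post comes too late: apart from the Demo that Baidu officially provides, I have seen very few open-source projects that use Paddle Lite for on-device inference. So today I'm sharing my first hands-on experience with Paddle Lite.
If you want to get started with Paddle Lite, the best first-hand material is still the Demo that Baidu open-sourced on GitHub. You can pretty much download it and run it: it is a ready-made project with no configuration missing, so it really is easy to pick up. The first part of this post walks through using that Demo; whatever else you do, get the Demo running first. After that I go into how to adapt the project into your own and use it for other purposes.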
1. Baidu's Paddle Lite object recognition Demo
1.1 Download the Demo project
On Linux:
git clone https://github.com/PaddlePaddle/Paddle-Lite-Demo.git
On the Windows Subsystem for Linux (WSL):
cd /mnt/d/desktop
git clone https://github.com/PaddlePaddle/Paddle-Lite-Demo.git
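You can choose the path to suit your own file layout. My commands first change into the desktop folder on drive D (the path there is not case-sensitive) and then clone the project into that local folder.
You can also click Download directly on the GitHub page and unzip the archive once it finishes.
The relevant GitHub repositories are: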
Paddle-Lite-Demo: https://github.com/PaddlePaddle/Paddle-Lite-Demo
Paddle-Lite: https://github.com/PaddlePaddle/Paddle-Lite
1.2 Environment setup
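This step checks your software setup. The Demo that Baidu provides uses a fairly recent Android Studio, so download the latest release from the official Android Studio site; anything 3.4 or later is fine.
Before setting out, make sure you have:
Java 8
Android Studio 3.4+
A reasonably recent Android SDK
An Android phone
If you say you have none of these, sorry, please stop reading here.
If you say you're an Apple user, sorry, I'm not.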
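1.3 Build and install
Connect the phone to the computer and make sure the phone already works with Android Studio; some models take a bit of fiddling. On my old phone it was enough to enable USB debugging in Developer options, and it was simple to root. My new phone's system turned out to be much more complicated, presumably for security reasons: my Xiaomi Mi 8 Lite needs account authorization before USB debugging can be turned on, and it still kept hitting bugs until I turned off MIUI optimization, which fixed it completely.
"Compile" isn't quite the right word, but that's the idea. Click Build --> Run, or click the green triangle in the toolbar (the Run button), and then wait. A long download and install sets up every component that is needed, after which the Demo runs on the phone.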
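This Demo is quite well written: it can switch between CPU and NPU, switch among three different MobileNet models, read images stored on the phone, and take photos with the camera. It touches on many Android development techniques, so it also makes good material for learning Android.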
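2. Building your own project
The official Demo is running, and now I want to move Paddle Lite into my own project; I can't keep hacking on the provided Demo forever. Besides, I'm picky about using my own package and class names!
Below I show how to build a project with similar functionality from scratch. I won't reimplement everything; the goal here is only to get your own project running and doing image classification inference.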
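This also picks up where the previous post left off.
Be careful to choose an Android version that is compatible with your phone. My Mi 8 Lite runs Android 8.1; the minimum version just shouldn't be too high, so I chose 5.0 here.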
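2.2 Add the Paddle Lite library files
Switch the Android view to the Project view, open the app directory, and add PaddlePredictor.jar to the libs folder (create the folder if it doesn't exist); just copy the file from the official Demo. Once the file shows up in Android Studio's project tree, right-click it --> Add as Library. There are quite a few menu options, so don't let them dazzle you.
Then click OK.
Add two folders, assets and jniLibs, at the path shown below. assets holds the fluid models plus the images and text resources the program needs; jniLibs holds the compiled Paddle Lite libraries.
Look at what I put in there. MobileNetV3 is copied from Baidu's Demo; ResNet18 is the inference model I compiled from a PyTorch model in the first post of this series (the previous post). You can start with just the official one, which is easier, and changing it later is simple.
jniLibs holds the compiled libraries for ARMv7 and ARMv8; just paste them in.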
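2.3 Create the Activities
As in the official Demo, I create two Activities, though you can decide based on your own needs. One is the main Activity and the other is dedicated to image classification. This leaves room for more features later: each gets an entry point on the main screen that jumps to its own Activity, just like the official Demo.
Change the layout in activity_main.xml to a LinearLayout and fix whatever errors come up, and it will run.
You can start from the official layout files. I deleted a lot of widgets because this is my first attempt and I only wanted the feature to work; it can be polished later.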
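2.4 Write the code
As long as you added PaddlePredictor as a library by following the steps above, build.gradle needs no changes: Android Studio has already generated the line implementation files('libs/PaddlePredictor.jar'). If yours is missing, add it and sync. Note that there are two build.gradle files, and this line belongs in the app module's one; don't edit the wrong file.
From here on it's just writing Java code and tweaking the UI layouts.
Here are the two UI files, to give you a short break from the prose!
Main screen UI: activity_main.xml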
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:orientation="vertical"
android:layout_width="fill_parent"
android:layout_height="fill_parent"
tools:context=".MainActivity">
<ScrollView
android:layout_width="fill_parent"
android:layout_height="fill_parent"
android:fadingEdge="vertical"
android:scrollbars="vertical">
<LinearLayout
android:layout_width="fill_parent"
android:layout_height="fill_parent"
android:orientation="vertical">
<LinearLayout
android:layout_width="fill_parent"
android:layout_height="200dp"
android:orientation="horizontal">
<RelativeLayout
android:id="@+id/v_img_classify"
android:layout_width="wrap_content"
android:layout_height="fill_parent"
android:layout_weight="1"
android:clickable="true"
android:onClick="onClick">
<ImageView
android:id="@+id/iv_img_classify_icon"
android:layout_width="100dp"
android:layout_height="100dp"
android:layout_centerHorizontal="true"
android:layout_centerVertical="true"
android:layout_margin="12dp"
android:adjustViewBounds="true"
android:src="@drawable/image_classfication"
android:scaleType="fitCenter"/>
<TextView
android:id="@+id/iv_img_classify_title"
android:layout_below="@id/iv_img_classify_icon"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_centerHorizontal="true"
android:layout_margin="8dp"
android:text="Image Classification"
android:textStyle="bold"
android:textAllCaps="false"
android:singleLine="false"/>
</RelativeLayout>
</LinearLayout>
</LinearLayout>
</ScrollView>
</LinearLayout>
Image classification screen UI: activity_classify.xml
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:orientation="vertical"
android:layout_width="match_parent"
android:layout_height="match_parent">
<LinearLayout
android:orientation="horizontal"
android:layout_weight="4"
android:layout_width="match_parent"
android:layout_height="match_parent">
<TextView
android:id="@+id/txtInfo"
android:layout_weight="3"
android:text="model"
android:textSize="8pt"
android:layout_width="match_parent"
android:layout_height="match_parent" />
<Switch
android:layout_weight="4"
android:id="@+id/switch_model"
android:textOff="Off"
android:textOn="On"
android:layout_width="match_parent"
android:layout_height="match_parent" />
</LinearLayout>
<ImageView
android:id="@+id/iv_image"
android:layout_weight="2"
android:layout_width="match_parent"
android:layout_height="match_parent"/>
<TextView
android:id="@+id/tv_result"
android:layout_weight="3"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:textStyle="bold"
android:textSize="12pt"
android:text="Inference result"/>
</LinearLayout>
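As for the Java code, I have the following modules. Calling them modules sounds professional, as if I had studied software engineering. Well, I have!
CommonActivity is the parent class of the image classification Activity (ImgClassifyActivity.java);
Predictor is the base class that calls the fluid model;
ImgClassifyPredictor extends it and defines details such as the input and output sizes;
and then there are the main program MainActivity.java and the utility class Utils.java.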
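To make "details such as the input and output sizes" a bit more concrete, here is a minimal plain-Java sketch of the preprocessing that ImgClassifyPredictor.runModel performs further below: scale each 8-bit channel to [0, 1], subtract the per-channel mean, divide by the per-channel standard deviation, and write the result in planar (NCHW) order. The class name NchwPreprocessSketch and the packed 0xRRGGBB int[] input are assumptions made for illustration; the real code reads pixels from a Bitmap.

```java
// Hypothetical helper for illustration; not part of the Demo or of Paddle Lite.
final class NchwPreprocessSketch {
    // Same per-channel statistics as ImgClassifyPredictor, in R, G, B order.
    static final float[] MEAN = {0.485f, 0.456f, 0.406f};
    static final float[] STD = {0.229f, 0.224f, 0.225f};

    /**
     * Converts packed 0xRRGGBB pixels (row-major, width * height entries)
     * into a normalized float array laid out as [R plane][G plane][B plane].
     */
    static float[] toNchw(int[] pixels, int width, int height) {
        if (pixels.length != width * height) {
            throw new IllegalArgumentException("pixel count does not match width * height");
        }
        int plane = width * height;
        float[] out = new float[3 * plane];
        for (int i = 0; i < plane; i++) {
            int color = pixels[i];
            float r = ((color >> 16) & 0xFF) / 255.0f;
            float g = ((color >> 8) & 0xFF) / 255.0f;
            float b = (color & 0xFF) / 255.0f;
            out[i] = (r - MEAN[0]) / STD[0];             // R plane
            out[i + plane] = (g - MEAN[1]) / STD[1];     // G plane
            out[i + 2 * plane] = (b - MEAN[2]) / STD[2]; // B plane
        }
        return out;
    }

    public static void main(String[] args) {
        // A 2x1 image: one pure red pixel followed by one black pixel.
        float[] data = toNchw(new int[]{0xFF0000, 0x000000}, 2, 1);
        System.out.println(java.util.Arrays.toString(data));
    }
}
```

The planar layout is why the original loop computes gIdx and bIdx by adding width * height to the red index: all red values come first, then all green, then all blue.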
//CommonActivity.java
public class CommonActivity extends AppCompatActivity {
private static final String TAG = CommonActivity.class.getSimpleName();
public static final int OPEN_GALLERY_REQUEST_CODE = 0;
public static final int TAKE_PHOTO_REQUEST_CODE = 1;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
ActionBar supportActionBar = getSupportActionBar();
if (supportActionBar != null) {
supportActionBar.setDisplayHomeAsUpEnabled(true);
}
}
public void onImageChanged(Bitmap image) {}
@Override
public boolean onCreateOptionsMenu(Menu menu) {
MenuInflater inflater = getMenuInflater();
inflater.inflate(R.menu.menu_action_options, menu);
return true;
}
@Override
public boolean onOptionsItemSelected(MenuItem item) {
switch (item.getItemId()) {
case android.R.id.home:
finish();
break;
case R.id.open_gallery:
if (requestAllPermissions())openGallery();
break;
case R.id.take_photo:
if (requestAllPermissions())takePhoto();
break;
}
return super.onOptionsItemSelected(item);
}
@Override
public void onRequestPermissionsResult(int requestCode, String[] permissions, int[] grantResults) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults);
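// An interrupted request delivers an empty array, so check the length before indexing.
if (grantResults.length < 2 || grantResults[0] != PackageManager.PERMISSION_GRANTED
|| grantResults[1] != PackageManager.PERMISSION_GRANTED) {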
Toast.makeText(this, "Permission Denied", Toast.LENGTH_SHORT).show();
}
}
private boolean requestAllPermissions() {
if (ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE)
!= PackageManager.PERMISSION_GRANTED || ContextCompat.checkSelfPermission(this,
Manifest.permission.CAMERA)
!= PackageManager.PERMISSION_GRANTED) {
ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.WRITE_EXTERNAL_STORAGE,
Manifest.permission.CAMERA},
0);
return false;
}
return true;
}
private void openGallery() {
Intent intent = new Intent(Intent.ACTION_PICK, null);
intent.setDataAndType(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, "image/*");
startActivityForResult(intent, OPEN_GALLERY_REQUEST_CODE);
}
private void takePhoto() {
Intent takePhotoIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
if (takePhotoIntent.resolveActivity(getPackageManager()) != null) {
startActivityForResult(takePhotoIntent, TAKE_PHOTO_REQUEST_CODE);
}
}
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
super.onActivityResult(requestCode, resultCode, data);
if (resultCode == RESULT_OK && data != null) {
switch (requestCode) {
case OPEN_GALLERY_REQUEST_CODE:
try {
ContentResolver resolver = getContentResolver();
Uri originalUri = data.getData();
// Decode the picked image directly from its content URI.
Bitmap imageData = MediaStore.Images.Media.getBitmap(resolver, originalUri);
onImageChanged(imageData);
} catch (IOException e) {
Log.e(TAG, e.toString());
}
break;
case TAKE_PHOTO_REQUEST_CODE:
Bundle extras = data.getExtras();
if (extras != null) { // the camera app may return no thumbnail
Bitmap imageData = (Bitmap) extras.get("data");
onImageChanged(imageData);
}
break;
default:
break;
}
}
}
}
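A note on onRequestPermissionsResult: grantResults carries one entry per requested permission, and Android documents that an interrupted request delivers empty arrays. A defensive check can be sketched in plain Java as follows. PermissionResultSketch and allGranted are hypothetical names, and the constant simply mirrors the value of PackageManager.PERMISSION_GRANTED so the sketch runs outside Android.

```java
// Hypothetical helper for illustration; not part of the Android SDK or the Demo.
final class PermissionResultSketch {
    // Mirrors the value of android.content.pm.PackageManager.PERMISSION_GRANTED.
    static final int PERMISSION_GRANTED = 0;

    /** Returns true only if at least one result is present and every result is granted. */
    static boolean allGranted(int[] grantResults) {
        if (grantResults == null || grantResults.length == 0) {
            return false; // treat an interrupted request as a denial
        }
        for (int result : grantResults) {
            if (result != PERMISSION_GRANTED) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(allGranted(new int[]{0, 0}));  // prints true
        System.out.println(allGranted(new int[]{0, -1})); // prints false
        System.out.println(allGranted(new int[0]));       // prints false
    }
}
```

Looping over the array instead of indexing [0] and [1] also keeps the check valid if the list of requested permissions changes later.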
Next: ImgClassifyActivity.java
public class ImgClassifyActivity extends CommonActivity {
protected TextView tvResult, txtInfo;
protected ImageView ivImageData;
protected String imagePath = "cat.jpg";
protected ImgClassifyPredictor predictor;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_classify);
tvResult = findViewById(R.id.tv_result);
txtInfo = findViewById(R.id.txtInfo);
ivImageData = findViewById(R.id.iv_image);
Switch sw = findViewById(R.id.switch_model);
// load Model
receiver = new Receiver();
sender = new Sender();
sw.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() {
@Override
public void onCheckedChanged(CompoundButton compoundButton, boolean b) {
if (b) {
predictor = new ImgClassifyPredictor("MobileNetV3");
}else {
predictor = new ImgClassifyPredictor("ResNet18");
}
sender.sendEmptyMessage(REQUEST_LOAD_MODEL);
}
});
sw.setChecked(true);
}
public boolean loadImage() {
try {
if (imagePath.isEmpty()) {
return false;
}
Bitmap imageData = null;
// read the test image from a custom path if imagePath starts with '/', otherwise read it
// from assets
if (!imagePath.substring(0, 1).equals("/")) {
InputStream imageStream = getAssets().open(imagePath);
imageData = BitmapFactory.decodeStream(imageStream);
} else {
if (!new File(imagePath).exists()) {
return false;
}
imageData = BitmapFactory.decodeFile(imagePath);
}
if (imageData != null && predictor.isLoaded) {
predictor.setImageData(imageData);
return true;
}
} catch (IOException e) {
Toast.makeText(ImgClassifyActivity.this, "Load image failed!", Toast.LENGTH_SHORT).show();
e.printStackTrace();
}
return false;
}
@Override
public boolean onPrepareOptionsMenu(Menu menu) {
boolean isLoaded = predictor.isLoaded;
menu.findItem(R.id.open_gallery).setEnabled(isLoaded);
menu.findItem(R.id.take_photo).setEnabled(isLoaded);
return super.onPrepareOptionsMenu(menu);
}
public void outputResult() {
Bitmap imageData = predictor.imageData();
if (imageData != null) {
ivImageData.setImageBitmap(imageData);
}
tvResult.setText(
predictor.top1Result
+ "\n" + predictor.top2Result
+ "\n" + predictor.top3Result
+ "\nInference time: " + predictor.inferenceTime + " ms");
}
@Override
public void onImageChanged(Bitmap imageData) {
// rerun model if users pick test image from gallery or camera
if (imageData != null && predictor.isLoaded) {
predictor.setImageData(imageData);
sender.sendEmptyMessage(REQUEST_RUN_MODEL);
}
super.onImageChanged(imageData);
}
@Override
protected void onDestroy() {
if (predictor != null) {
predictor.releaseModel();
}
super.onDestroy();
}
//##########################################################################
public static final int REQUEST_LOAD_MODEL = 0;
public static final int REQUEST_RUN_MODEL = 1;
public static final int RESPONSE_LOAD_MODEL_SUCCESS = 0;
public static final int RESPONSE_LOAD_MODEL_FAILED = 1;
public static final int RESPONSE_RUN_MODEL_SUCCESS = 2;
public static final int RESPONSE_RUN_MODEL_FAILED = 3;
private Receiver receiver;
private Sender sender;
class Receiver extends Handler {// receive messages from worker thread
@Override
public void handleMessage(Message msg) {
switch (msg.what) {
case RESPONSE_LOAD_MODEL_SUCCESS:
txtInfo.setText(predictor.modelName+"--load!");
if (loadImage())sender.sendEmptyMessage(REQUEST_RUN_MODEL);
break;
case RESPONSE_LOAD_MODEL_FAILED:
txtInfo.setText(predictor.modelName+"--fail!");
break;
case RESPONSE_RUN_MODEL_SUCCESS:
txtInfo.setText(predictor.modelName+"--run!");
// obtain results and update UI
outputResult();
break;
case RESPONSE_RUN_MODEL_FAILED:
txtInfo.setText(predictor.modelName+"--stop!");
break;
default:
break;
}
}
};
class Sender extends Handler {// handle load/run requests on the thread that created this Handler
public void handleMessage(Message msg) {
switch (msg.what) {
case REQUEST_LOAD_MODEL:
// load model and reload test image
if (predictor.init(getApplication())) {
receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_SUCCESS);
} else {
receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_FAILED);
}
break;
case REQUEST_RUN_MODEL:
// run model if model is loaded
if (predictor.isLoaded && predictor.runModel()) {
receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_SUCCESS);
} else {
receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_FAILED);
}
break;
default:
break;
}
}
};
}
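One thing to be aware of in the listing above: Receiver and Sender are both created in onCreate with the default Handler constructor, so both are bound to the main thread's Looper, and predictor.init and predictor.runModel therefore run on the UI thread. If loading a larger model makes the screen stutter, one possible approach is to run the work on a background thread. Below is a minimal plain-Java sketch using ExecutorService; BackgroundTaskSketch and its response codes are hypothetical, and on Android the result would still have to be posted back to the main thread before touching any views.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.BooleanSupplier;

// Hypothetical sketch; the class and method names are not from the Demo.
final class BackgroundTaskSketch {
    static final int RESPONSE_SUCCESS = 0;
    static final int RESPONSE_FAILED = 1;

    // A single worker thread keeps load and run requests in submission order.
    private final ExecutorService worker = Executors.newSingleThreadExecutor();

    /** Runs the task on the worker thread and maps its boolean result to a response code. */
    Future<Integer> submit(BooleanSupplier task) {
        return worker.submit(() -> task.getAsBoolean() ? RESPONSE_SUCCESS : RESPONSE_FAILED);
    }

    void shutdown() {
        worker.shutdown();
    }

    public static void main(String[] args) throws Exception {
        BackgroundTaskSketch sketch = new BackgroundTaskSketch();
        try {
            // Stand-in for predictor.init(...): pretend the model loaded successfully.
            Future<Integer> response = sketch.submit(() -> true);
            System.out.println("response code: " + response.get());
        } finally {
            sketch.shutdown();
        }
    }
}
```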
Next: Predictor.java
public class Predictor {
public boolean isLoaded = false;
public float inferenceTime = 0;
public String modelName = "";
protected Context appCtx = null;
protected int whichDevice = 0; // 0: CPU 1: NPU
protected ArrayList<PaddlePredictor> paddlePredictors = new ArrayList<PaddlePredictor>(); // 0: CPU 1: NPU
public Predictor() {
}
public boolean init(Context appCtx, String modelPath) {
this.appCtx = appCtx;
isLoaded = loadModel(modelPath);
return isLoaded;
}
protected boolean loadModel(String modelPath) {
// release model if exists
releaseModel();
// load model
if (modelPath.isEmpty())return false;
String realPath = modelPath;
if (!modelPath.substring(0, 1).equals("/")) {
// a model path starting with '/' is a custom path and is used as-is;
// otherwise the model is copied from assets into the cache directory
realPath = appCtx.getCacheDir() + "/" + modelPath;
Utils.copyDirectoryFromAssets(appCtx, modelPath, realPath);
}
if (realPath.isEmpty())return false;
modelName = realPath.substring(realPath.lastIndexOf("/") + 1);
// run on CPU
CxxConfig config = new CxxConfig();
config.setModelDir(realPath);
Place preferredPlace = new Place(Place.TargetType.ARM, Place.PrecisionType.FLOAT);
Place[] validPlaces = new Place[2];
validPlaces[0] = new Place(Place.TargetType.HOST, Place.PrecisionType.FLOAT);
validPlaces[1] = new Place(Place.TargetType.ARM, Place.PrecisionType.FLOAT);
config.setPreferredPlace(preferredPlace);
config.setValidPlaces(validPlaces);
paddlePredictors.add(PaddlePredictor.createPaddlePredictor(config));
return true;
}
public void releaseModel() {
paddlePredictors.clear();
isLoaded = false;
modelName = "";
}
public Tensor getInput(int idx) {
if (paddlePredictors.size() < whichDevice + 1) {
return null;
}
return paddlePredictors.get(whichDevice).getInput(idx);
}
public Tensor getOutput(int idx) {
if (paddlePredictors.size() < whichDevice + 1) {
return null;
}
return paddlePredictors.get(whichDevice).getOutput(idx);
}
public boolean runModel() {
if (paddlePredictors.size() < whichDevice + 1) {
return false;
}
// warm up
paddlePredictors.get(whichDevice).run();
// inference
Date start = new Date();
paddlePredictors.get(whichDevice).run();
Date end = new Date();
inferenceTime = (end.getTime() - start.getTime());
return true;
}
}
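The path convention in loadModel is easy to miss, so here it is pulled out into a small plain-Java sketch: a path beginning with '/' is treated as a file-system location, and anything else is treated as an asset name to be copied under the cache directory. ModelPathSketch, resolve, and modelName are hypothetical names; the Demo does all of this inline.

```java
import java.io.File;

// Hypothetical helper for illustration; the Demo does this inline in loadModel.
final class ModelPathSketch {
    /**
     * Returns the directory the model should be loaded from.
     * A path starting with '/' is used as-is; anything else is treated as an
     * asset name that will be copied under cacheDir.
     */
    static String resolve(File cacheDir, String modelPath) {
        if (modelPath == null || modelPath.isEmpty()) {
            throw new IllegalArgumentException("modelPath must not be empty");
        }
        if (modelPath.startsWith("/")) {
            return modelPath;
        }
        return cacheDir.getPath() + "/" + modelPath;
    }

    /** Derives the display name the same way loadModel does: the text after the last '/'. */
    static String modelName(String realPath) {
        return realPath.substring(realPath.lastIndexOf("/") + 1);
    }

    public static void main(String[] args) {
        String real = resolve(new File("cache"), "MobileNetV3");
        System.out.println(real + " -> " + modelName(real));
    }
}
```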
Then: ImgClassifyPredictor.java
public class ImgClassifyPredictor extends Predictor {
private static final String TAG = ImgClassifyPredictor.class.getSimpleName();
protected Vector<String> wordLabels = new Vector<String>();
protected Bitmap imageData = null;
protected String top1Result = "";
protected String top2Result = "";
protected String top3Result = "";
protected float preprocessTime = 0;
protected float postprocessTime = 0;
// model config
protected String modelPath = "MobileNetV3";
protected String labelPath = "synset_words.txt";
protected long[] inputShape = new long[]{1,3,224,224};
protected float[] inputMean = new float[]{0.485f,0.456f,0.406f};
protected float[] inputStd = new float[]{0.229f,0.224f,0.225f};
public ImgClassifyPredictor(String modelPath) {
super();
this.modelPath = modelPath;
}
public boolean init(Context appCtx) {
if (inputShape.length != 4) {
Log.i(TAG, "size of input shape should be: 4");
return false;
}
if (inputMean.length != inputShape[1]) {
Log.i(TAG, "size of input mean should be: " + Long.toString(inputShape[1]));
return false;
}
if (inputStd.length != inputShape[1]) {
Log.i(TAG, "size of input std should be: " + Long.toString(inputShape[1]));
return false;
}
if (inputShape[0] != 1) {
Log.i(TAG, "only one batch is supported in the image classification demo, you can use any batch size in " +
"your Apps!");
return false;
}
if (inputShape[1] != 1 && inputShape[1] != 3) {
Log.i(TAG, "only one/three channels are supported in the image classification demo, you can use any " +
"channel size in your Apps!");
return false;
}
super.init(appCtx, modelPath);
if (!super.isLoaded) {
return false;
}
isLoaded &= loadLabel(labelPath);
return isLoaded;
}
protected boolean loadLabel(String labelPath) {
wordLabels.clear();
// load word labels from file
try {
InputStream assetsInputStream = appCtx.getAssets().open(labelPath);
int available = assetsInputStream.available();
byte[] lines = new byte[available];
assetsInputStream.read(lines);
assetsInputStream.close();
String words = new String(lines);
String[] contents = words.split("\n");
for (String content : contents) {
int first_space_pos = content.indexOf(" ");
if (first_space_pos >= 0 && first_space_pos < content.length()) {
wordLabels.add(content.substring(first_space_pos));
}
}
Log.i(TAG, "word label size: " + wordLabels.size());
} catch (Exception e) {
Log.e(TAG, e.getMessage());
return false;
}
return true;
}
public Tensor getInput(int idx) {
return super.getInput(idx);
}
public Tensor getOutput(int idx) {
return super.getOutput(idx);
}
public boolean runModel(Bitmap imageData) {
setImageData(imageData);
return runModel();
}
public boolean runModel() {
if (imageData == null) {
return false;
}
// set input shape
Tensor inputTensor = getInput(0);
inputTensor.resize(inputShape);
// pre-process image, and feed input tensor with pre-processed data
Date start = new Date();
int channels = (int) inputShape[1];
int width = (int) inputShape[3];
int height = (int) inputShape[2];
float[] inputData = new float[channels * width * height];
for (int i = 0; i < height; i++) {
for (int j = 0; j < width; j++) {
int color = imageData.getPixel(j, i);
float r = (float) red(color) / 255.0f;
float g = (float) green(color) / 255.0f;
float b = (float) blue(color) / 255.0f;
if (channels == 3) {
r = r - inputMean[0];
g = g - inputMean[1];
b = b - inputMean[2];
r = r / inputStd[0];
g = g / inputStd[1];
b = b / inputStd[2];
int rIdx = i * width + j;
int gIdx = rIdx + width * height;
int bIdx = gIdx + width * height;
inputData[rIdx] = r;
inputData[gIdx] = g;
inputData[bIdx] = b;
} else { // channels = 1
float gray = (b + g + r) / 3.0f;
gray = gray - inputMean[0];
gray = gray / inputStd[0];
inputData[i * width + j] = gray;
}
}
}
inputTensor.setData(inputData);
Date end = new Date();
preprocessTime = (float) (end.getTime() - start.getTime());
// inference
super.runModel();
// fetch output tensor
Tensor outputTensor = getOutput(0);
// post-process
start = new Date();
long outputShape[] = outputTensor.shape();
long outputSize = 1;
for (long s : outputShape) {
outputSize *= s;
}
float[] scores = outputTensor.getFloatData(); // fetch the output once, not once per element
int[] max_index = new int[3]; // top3 indices
double[] max_num = new double[3]; // top3 scores
for (int i = 0; i < outputSize; i++) {
double tmp = scores[i];
int tmp_index = i;
for (int j = 0; j < 3; j++) {
if (tmp > max_num[j]) {
// put the candidate in slot j and carry the displaced entry down the list
int displaced_index = max_index[j];
max_index[j] = tmp_index;
tmp_index = displaced_index;
double displaced_num = max_num[j];
max_num[j] = tmp;
tmp = displaced_num;
}
}
}
end = new Date();
postprocessTime = (float) (end.getTime() - start.getTime());
if (wordLabels.size() > 0) {
top1Result = "Top1: " + wordLabels.get(max_index[0]) + " - " + String.format("%.4f", max_num[0]);
top2Result = "Top2: " + wordLabels.get(max_index[1]) + " - " + String.format("%.4f", max_num[1]);
top3Result = "Top3: " + wordLabels.get(max_index[2]) + " - " + String.format("%.4f", max_num[2]);
}
return true;
}
public Bitmap imageData() {
return imageData;
}
public void setImageData(Bitmap imageData) {
if (imageData == null) {
return;
}
// scale image to the size of input tensor
Bitmap rgbaData = imageData.copy(Bitmap.Config.ARGB_8888, true);
Bitmap scaleData = Bitmap.createScaledBitmap(rgbaData, (int) inputShape[3], (int) inputShape[2], true);
this.imageData = scaleData;
}
}
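runModel ends by picking the three highest scores with an in-place swap loop. The same idea, generalized to any k, can be sketched in plain Java as below. TopKSketch and topK are hypothetical names. Unlike the listing above, which starts every slot at a score of 0, this sketch also handles scores that are zero or negative.

```java
// Hypothetical helper for illustration; not part of the Demo or of Paddle Lite.
final class TopKSketch {
    /**
     * Returns the indices of the k largest scores, highest first.
     * On ties the lower index comes first.
     */
    static int[] topK(float[] scores, int k) {
        if (k < 0 || k > scores.length) {
            throw new IllegalArgumentException("k must be between 0 and scores.length");
        }
        int[] best = new int[k];
        int filled = 0; // number of valid entries in best
        for (int i = 0; i < scores.length; i++) {
            // Walk up from the end of the sorted prefix to find where score i belongs.
            int pos = filled;
            while (pos > 0 && scores[i] > scores[best[pos - 1]]) {
                pos--;
            }
            if (pos >= k) {
                continue; // not better than any of the current top k
            }
            // Shift lower-ranked entries down by one, dropping the last if full.
            int last = Math.min(filled, k - 1);
            for (int j = last; j > pos; j--) {
                best[j] = best[j - 1];
            }
            best[pos] = i;
            if (filled < k) {
                filled++;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[] top = topK(new float[]{0.1f, 0.7f, 0.05f, 0.15f}, 3);
        System.out.println(java.util.Arrays.toString(top)); // prints [1, 3, 0]
    }
}
```

The returned indices can then be used to look up entries in wordLabels, as the listing does with max_index.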
Finally, the main program: MainActivity.java
public class MainActivity extends AppCompatActivity implements View.OnClickListener {
private static final String TAG = MainActivity.class.getSimpleName();
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
findViewById(R.id.v_img_classify).setOnClickListener(this);
Intent intent = new Intent(MainActivity.this, ImgClassifyActivity.class);
startActivity(intent);
}
@Override
public void onClick(View v) {
switch (v.getId()) {
case R.id.v_img_classify:
Intent intent = new Intent(MainActivity.this, ImgClassifyActivity.class);
startActivity(intent);
break;
}
}
@Override
protected void onDestroy() {
super.onDestroy();
System.exit(0);
}
}
And the utility class: Utils.java
public class Utils {
public static void copyFileFromAssets(Context appCtx, String srcPath, String dstPath) {
if (srcPath.isEmpty() || dstPath.isEmpty()) {
return;
}
InputStream is = null;
OutputStream os = null;
try {
is = new BufferedInputStream(appCtx.getAssets().open(srcPath));
os = new BufferedOutputStream(new FileOutputStream(new File(dstPath)));
byte[] buffer = new byte[1024];
int length = 0;
while ((length = is.read(buffer)) != -1) {
os.write(buffer, 0, length);
}
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
} finally {
try {
// either stream is still null if opening it failed
if (os != null) {
os.close();
}
if (is != null) {
is.close();
}
} catch (IOException e) {
e.printStackTrace();
}
}
}
public static void copyDirectoryFromAssets(Context appCtx, String srcDir, String dstDir) {
if (srcDir.isEmpty() || dstDir.isEmpty()) {
return;
}
try {
if (!new File(dstDir).exists()) {
new File(dstDir).mkdirs();
}
for (String fileName : appCtx.getAssets().list(srcDir)) {
String srcSubPath = srcDir + File.separator + fileName;
String dstSubPath = dstDir + File.separator + fileName;
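// srcSubPath is an asset path, not a file-system path, so java.io.File cannot tell
// whether it is a directory; a non-empty asset listing means it is one
String[] children = appCtx.getAssets().list(srcSubPath);
if (children != null && children.length > 0) {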
copyDirectoryFromAssets(appCtx, srcSubPath, dstSubPath);
} else {
copyFileFromAssets(appCtx, srcSubPath, dstSubPath);
}
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
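The copy loop in copyFileFromAssets does not depend on Android at all, so it can be tried out with in-memory streams. The sketch below uses try-with-resources, which closes each stream that was actually opened without any null checks in a finally block. StreamCopySketch and copy are hypothetical names, and the byte-array streams stand in for the asset input stream and the cache file output stream.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;

// Hypothetical helper for illustration; not part of the Demo.
final class StreamCopySketch {
    /** Copies everything from in to out and returns the number of bytes copied. */
    static long copy(InputStream in, OutputStream out) throws IOException {
        byte[] buffer = new byte[1024];
        long total = 0;
        int length;
        while ((length = in.read(buffer)) != -1) {
            out.write(buffer, 0, length);
            total += length;
        }
        return total;
    }

    public static void main(String[] args) throws IOException {
        byte[] payload = "model bytes".getBytes(StandardCharsets.UTF_8);
        // Both streams are closed automatically, even if copy throws.
        try (InputStream in = new ByteArrayInputStream(payload);
             ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            long copied = copy(in, out);
            System.out.println("copied " + copied + " bytes");
        }
    }
}
```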
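None of the listings above include the package declaration or imports; please add them yourself.
Finally, let me shamelessly post a few screenshots of the app running; you already saw some of them in the previous post.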