Baidu Paddle-Lite Tutorial: Hands-On Image Classification App Development

Introduction:

In 2019 deep learning caught fire, and neural networks found their way into every industry. I wanted to join the ranks of AI engineers too. Right now I need to build an Android app that uses classification, and since the big vendors are all racing to release their own AI inference frameworks, I figured I would openly grab my share of the pie. I hope this article does not come too late: apart from the official demo Baidu provides, there are still very few open-source projects showing how to do on-device inference with Paddle Lite. So today I am sharing my first hands-on experience with it.

The best first-hand material for getting started with Paddle Lite is still the demo Baidu open-sourced on GitHub. It is a ready-made project that runs as soon as you download it, with nothing left to configure, so it is very easy to pick up. The first part of this article walks through using that demo: get it running first, no questions asked. After that, I explain how to adapt the project into your own and extend it to other features.

1. Baidu's Paddle Lite Object Recognition Demo

1.1 Download the Demo Project

On Linux:

git clone https://github.com/PaddlePaddle/Paddle-Lite-Demo.git

On the Windows Subsystem for Linux:

cd /mnt/d/desktop

git clone https://github.com/PaddlePaddle/Paddle-Lite-Demo.git

Adjust the path to suit your own setup; my commands first change into the desktop folder on drive D (the mounted Windows path is case-insensitive), then clone the project into a local folder.

Alternatively, download the ZIP archive directly from the GitHub page and extract it.

Here are the corresponding GitHub repositories:

Paddle-Lite-Demo:

PaddlePaddle/Paddle-Lite-Demo (github.com)

Paddle-Lite:

PaddlePaddle/Paddle-Lite (github.com)

1.2 Environment Setup

Now verify your software versions. The demo Baidu provides was built with a fairly recent Android Studio, so download the latest version from the official site; anything 3.4+ is fine.

Before setting off, make sure you have installed:

Java 8

Android Studio 3.4+

A reasonably recent Android SDK

An Android phone

If you don't have one, sorry, you may as well stop reading here.

If you're an iPhone user: sorry, I'm not.

1.3 Build and Install

Connect your phone to the computer and make sure it can work with Android Studio; some models really do need extra fiddling. On my old phone, enabling USB debugging in Developer Options was enough, and it was easy to root. After switching to a new phone, the system turned out to be much more locked down, presumably for security: my Xiaomi Mi 8 Lite required account authorization on top of USB debugging, and it still kept glitching. In the end, turning off "MIUI optimization" solved it completely.

"Compile" is not quite the right word, but you get the idea. Click Build → Run, or the green triangle (the Run button) on the toolbar, then wait. A lengthy download-and-install pass sets up every component needed, after which the demo runs on your phone.

The demo is quite well written: it can switch between CPU and NPU, switch among three different MobileNet models, browse the phone's photo library, and take pictures with the camera. It touches on a good number of Android development techniques, so it also makes a decent study resource for Android itself.


2. Building Your Own Project

With the official demo running, I wanted to migrate Paddle Lite into my own project; I can't keep hacking on the provided demo forever. Besides, I have a thing about using my own package and class names.

Below I show how to build a project with similar functionality from scratch. I won't reimplement every feature; the goal is just to get your own project running and doing image classification inference.

This also picks up where the previous article left off.

2.1 Create the Project

When creating the project, pick an Android version compatible with your phone. My Mi 8 Lite runs Android 8.1, so any minimum SDK that isn't too high will do; I chose 5.0.


2.2 Add the Paddle Lite Library Files

Switch the Android view to the Project view, open the app directory, and add PaddlePredictor.jar under the libs folder (create it if it doesn't exist); just copy this file over from the official demo. Once Android Studio shows the file in the tree, right-click it and choose Add as Library. There are a lot of menu options, so don't get lost.


Then click OK.


Under the app's src/main directory, add two folders, assets and jniLibs: assets holds the fluid model plus the images and text resources the app needs to load; jniLibs holds the Paddle Lite compiled libraries.


So what did I put in there? MobileNetV3 is copied from the Baidu demo; ResNet18 is the inference model compiled from a PyTorch model in the first article of this series (the previous post). You can start with just the official model to get going; swapping in your own later is easy.


jniLibs holds the compiled libraries for armV7 and armV8; just paste them in from the demo.


2.3 Create the Activities

As in the official demo, create two Activities (or structure it however your needs dictate). One is the main Activity; the other is dedicated to image classification. This leaves room to add more features later: each feature gets an entry point on the main screen, and tapping it jumps to that feature's own Activity, just like the official demo.


Change the root layout of activity_main.xml to a LinearLayout, fix whatever errors remain, and it will run. I reworked the official layout file and deleted a lot of components; since this is a first pass I only cared about getting the functionality working, and the UI can be polished later.

2.4 Write the Code

As long as you added PaddlePredictor.jar as a library in the earlier step, build.gradle needs no edits; the IDE has already generated the line implementation files('libs/PaddlePredictor.jar'). If yours doesn't have it, add it and sync. Note there are two build.gradle files; don't edit the wrong one.

From here on it's just writing Java and tweaking the UI layout.

First, the two UI files, a short break from the wall of text!


Main screen UI: activity_main.xml

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
     xmlns:app="http://schemas.android.com/apk/res-auto"
     xmlns:tools="http://schemas.android.com/tools"
    android:orientation="vertical"
     android:layout_width="fill_parent"
     android:layout_height="fill_parent"
     tools:context=".MainActivity">
 
    <ScrollView
            android:layout_width="fill_parent"
            android:layout_height="fill_parent"
            android:fadingEdge="vertical"
            android:scrollbars="vertical">
 
        <LinearLayout
                android:layout_width="fill_parent"
                android:layout_height="fill_parent"
                android:orientation="vertical">
 
            <LinearLayout
                    android:layout_width="fill_parent"
                    android:layout_height="200dp"
                    android:orientation="horizontal">
 
                <RelativeLayout
                        android:id="@+id/v_img_classify"
                        android:layout_width="wrap_content"
                        android:layout_height="fill_parent"
                        android:layout_weight="1"
                        android:clickable="true"
                        android:onClick="onClick">
 
                    <ImageView
                            android:id="@+id/iv_img_classify_icon"
                            android:layout_width="100dp"
                            android:layout_height="100dp"
                            android:layout_centerHorizontal="true"
                            android:layout_centerVertical="true"
                            android:layout_margin="12dp"
                            android:adjustViewBounds="true"
                            android:src="@drawable/image_classfication"
                            android:scaleType="fitCenter"/>
 
                    <TextView
                            android:id="@+id/iv_img_classify_title"
                            android:layout_below="@id/iv_img_classify_icon"
                            android:layout_width="wrap_content"
                            android:layout_height="wrap_content"
                            android:layout_centerHorizontal="true"
                            android:layout_margin="8dp"
                            android:text="Image Classification"
                            android:textStyle="bold"
                            android:textAllCaps="false"
                            android:singleLine="false"/>
                </RelativeLayout>
 
            </LinearLayout>
        </LinearLayout>
    </ScrollView>
</LinearLayout>

Image classification screen UI: activity_classify.xml

<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
     xmlns:app="http://schemas.android.com/apk/res-auto"
     xmlns:tools="http://schemas.android.com/tools"
    android:orientation="vertical"
     android:layout_width="match_parent"
     android:layout_height="match_parent">
 
    <LinearLayout
        android:orientation="horizontal"
        android:layout_weight="4"
        android:layout_width="match_parent"
        android:layout_height="match_parent">
        <TextView
            android:id="@+id/txtInfo"
            android:layout_weight="3"
            android:text="model"
            android:textSize="8pt"
            android:layout_width="match_parent"
            android:layout_height="match_parent" />
 
        <Switch
            android:layout_weight="4"
            android:id="@+id/switch_model"
            android:textOff="Off"
            android:textOn="On"
            android:layout_width="match_parent"
            android:layout_height="match_parent" />
    </LinearLayout>
 
    <ImageView
        android:id="@+id/iv_image"
        android:layout_weight="2"
        android:layout_width="match_parent"
        android:layout_height="match_parent"/>
 
    <TextView
        android:id="@+id/tv_result"
        android:layout_weight="3"
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:textStyle="bold"
        android:textSize="12pt"
        android:text="Inference result"/>
 
</LinearLayout>

As for the Java code, I organize it into these modules; calling them "modules" sounds professional, like someone who studied software engineering. Which I did!


CommonActivity is the parent class of the image classification Activity (ImgClassifyActivity.java);
Predictor is the base class that drives the fluid model;
ImgClassifyPredictor extends it and defines details such as the input/output sizes;
then there are the main MainActivity.java and the utility class Utils.java.

//CommonActivity.java
public class CommonActivity extends AppCompatActivity {
    private static final String TAG = CommonActivity.class.getSimpleName();
    public static final int OPEN_GALLERY_REQUEST_CODE = 0;
    public static final int TAKE_PHOTO_REQUEST_CODE = 1;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

        ActionBar supportActionBar = getSupportActionBar();
        if (supportActionBar != null) {
            supportActionBar.setDisplayHomeAsUpEnabled(true);
        }
    }

    public void onImageChanged(Bitmap image) {}

    @Override
    public boolean onCreateOptionsMenu(Menu menu) {
        MenuInflater inflater = getMenuInflater();
        inflater.inflate(R.menu.menu_action_options, menu);
        return true;
    }

    @Override
    public boolean onOptionsItemSelected(MenuItem item) {
        switch (item.getItemId()) {
            case android.R.id.home:
                finish();
                break;
            case R.id.open_gallery:
                if (requestAllPermissions())openGallery();
                break;
            case R.id.take_photo:
                if (requestAllPermissions())takePhoto();
                break;
        }
        return super.onOptionsItemSelected(item);
    }

    @Override
    public void onRequestPermissionsResult(int requestCode, String[] permissions, int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        if (grantResults[0] != PackageManager.PERMISSION_GRANTED || grantResults[1] != PackageManager.PERMISSION_GRANTED) {
            Toast.makeText(this, "Permission Denied", Toast.LENGTH_SHORT).show();
        }
    }

    private boolean requestAllPermissions() {
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE)
                != PackageManager.PERMISSION_GRANTED || ContextCompat.checkSelfPermission(this,
                Manifest.permission.CAMERA)
                != PackageManager.PERMISSION_GRANTED) {
            ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.WRITE_EXTERNAL_STORAGE,
                            Manifest.permission.CAMERA},
                    0);
            return false;
        }
        return true;
    }

    private void openGallery() {
        Intent intent = new Intent(Intent.ACTION_PICK, null);
        intent.setDataAndType(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, "image/*");
        startActivityForResult(intent, OPEN_GALLERY_REQUEST_CODE);
    }

    private void takePhoto() {
        Intent takePhotoIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
        if (takePhotoIntent.resolveActivity(getPackageManager()) != null) {
            startActivityForResult(takePhotoIntent, TAKE_PHOTO_REQUEST_CODE);
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (resultCode == RESULT_OK && data != null) {
            switch (requestCode) {
                case OPEN_GALLERY_REQUEST_CODE:
                    try {
                        ContentResolver resolver = getContentResolver();
                        Uri originalUri = data.getData();
                        Bitmap imageData = MediaStore.Images.Media.getBitmap(resolver, originalUri);
                        onImageChanged(imageData);
                    } catch (IOException e) {
                        Log.e(TAG, e.toString());
                    }
                    break;
                case TAKE_PHOTO_REQUEST_CODE:
                    Bundle extras = data.getExtras();
                    Bitmap imageData = (Bitmap) extras.get("data");
                    onImageChanged(imageData);
                    break;
                default:
                    break;
            }
        }
    }

}

Next up: ImgClassifyActivity.java

public class ImgClassifyActivity extends CommonActivity {
    protected TextView tvResult, txtInfo;
    protected ImageView ivImageData;

    protected String imagePath = "cat.jpg";
    protected ImgClassifyPredictor predictor;


    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_classify);

        tvResult = findViewById(R.id.tv_result);
        txtInfo = findViewById(R.id.txtInfo);
        ivImageData = findViewById(R.id.iv_image);

        Switch sw = findViewById(R.id.switch_model);

//        set up the message handlers; the switch listener below triggers model loading
        receiver = new Receiver();
        sender = new Sender();
        sw.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() {
            @Override
            public void onCheckedChanged(CompoundButton compoundButton, boolean b) {
                if (b) {
                    predictor = new ImgClassifyPredictor("MobileNetV3");
                }else {
                    predictor = new ImgClassifyPredictor("ResNet18");
                }
                sender.sendEmptyMessage(REQUEST_LOAD_MODEL);
            }
        });
        sw.setChecked(true);
    }

    public boolean loadImage() {
        try {
            if (imagePath.isEmpty()) {
                return false;
            }
            Bitmap imageData = null;
            // read test image file from custom path if the first character of mode path is '/', otherwise read test
            // image file from assets
            if (!imagePath.substring(0, 1).equals("/")) {
                InputStream imageStream = getAssets().open(imagePath);
                imageData = BitmapFactory.decodeStream(imageStream);
            } else {
                if (!new File(imagePath).exists()) {
                    return false;
                }
                imageData = BitmapFactory.decodeFile(imagePath);
            }
            if (imageData != null && predictor.isLoaded) {
                predictor.setImageData(imageData);
                return true;
            }
        } catch (IOException e) {
            Toast.makeText(ImgClassifyActivity.this, "Load image failed!", Toast.LENGTH_SHORT).show();
            e.printStackTrace();
        }
        return false;
    }


    @Override
    public boolean onPrepareOptionsMenu(Menu menu) {
        boolean isLoaded = predictor.isLoaded;
        menu.findItem(R.id.open_gallery).setEnabled(isLoaded);
        menu.findItem(R.id.take_photo).setEnabled(isLoaded);
        return super.onPrepareOptionsMenu(menu);
    }


    public void outputResult() {
        Bitmap imageData = predictor.imageData();
        if (imageData != null) {
            ivImageData.setImageBitmap(imageData);
        }
        tvResult.setText(
                predictor.top1Result
                + "\n" + predictor.top2Result
                + "\n" + predictor.top3Result
                + "\nInference time: " + predictor.inferenceTime + " ms");
    }

    @Override
    public void onImageChanged(Bitmap imageData) {
        // rerun model if users pick test image from gallery or camera
        if (imageData != null && predictor.isLoaded) {
            predictor.setImageData(imageData);
            sender.sendEmptyMessage(REQUEST_RUN_MODEL);
        }
        super.onImageChanged(imageData);
    }


    @Override
    protected void onDestroy() {
        if (predictor != null) {
            predictor.releaseModel();
        }
        super.onDestroy();
    }

    //##########################################################################
    public static final int REQUEST_LOAD_MODEL = 0;
    public static final int REQUEST_RUN_MODEL = 1;

    public static final int RESPONSE_LOAD_MODEL_SUCCESS = 0;
    public static final int RESPONSE_LOAD_MODEL_FAILED = 1;
    public static final int RESPONSE_RUN_MODEL_SUCCESS = 2;
    public static final int RESPONSE_RUN_MODEL_FAILED = 3;

    private Receiver receiver;
    private Sender sender;

    class Receiver extends Handler {// receive messages from worker thread
        @Override
        public void handleMessage(Message msg) {
            switch (msg.what) {
                case RESPONSE_LOAD_MODEL_SUCCESS:
                    txtInfo.setText(predictor.modelName+"--load!");
                    if (loadImage())sender.sendEmptyMessage(REQUEST_RUN_MODEL);
                    break;
                case RESPONSE_LOAD_MODEL_FAILED:
                    txtInfo.setText(predictor.modelName+"--fail!");
                    break;
                case RESPONSE_RUN_MODEL_SUCCESS:
                    txtInfo.setText(predictor.modelName+"--run!");
                    // obtain results and update UI
                    outputResult();
                    break;
                case RESPONSE_RUN_MODEL_FAILED:
                    txtInfo.setText(predictor.modelName+"--stop!");
                    break;
                default:
                    break;
            }
        }
    };

    class Sender extends Handler {// send command to worker thread
        public void handleMessage(Message msg) {
            switch (msg.what) {
                case REQUEST_LOAD_MODEL:
                    // load model and reload test image
                    if (predictor.init(getApplication())) {
                        receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_SUCCESS);
                    } else {
                        receiver.sendEmptyMessage(RESPONSE_LOAD_MODEL_FAILED);
                    }
                    break;
                case REQUEST_RUN_MODEL:
                    // run model if model is loaded
                    if (predictor.isLoaded && predictor.runModel()) {
                        receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_SUCCESS);
                    } else {
                        receiver.sendEmptyMessage(RESPONSE_RUN_MODEL_FAILED);
                    }
                    break;
                default:
                    break;
            }
        }
    };
}
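The Sender/Receiver protocol above can be illustrated outside Android. The sketch below (my own illustrative code; HandlerFlowDemo and runFlow are hypothetical names, and a BlockingQueue stands in for Android's Handler) replays the two requests through a worker thread and collects the response codes. Note that in the Activity above both Handlers are created on the UI thread; attaching Sender to a HandlerThread instead would keep model loading off the UI thread.

```java
import java.util.concurrent.*;

public class HandlerFlowDemo {
    // Same message codes as ImgClassifyActivity.
    static final int REQUEST_LOAD_MODEL = 0;
    static final int REQUEST_RUN_MODEL  = 1;
    static final int RESPONSE_LOAD_MODEL_SUCCESS = 0;
    static final int RESPONSE_RUN_MODEL_SUCCESS  = 2;

    // Feed the two requests through a worker thread and collect the responses,
    // mimicking the Sender (requests) / Receiver (responses) round trip.
    static java.util.List<Integer> runFlow() throws InterruptedException {
        BlockingQueue<Integer> requests = new LinkedBlockingQueue<>();
        java.util.List<Integer> responses = new CopyOnWriteArrayList<>();
        Thread worker = new Thread(() -> {
            try {
                while (true) {
                    int msg = requests.take();
                    if (msg == REQUEST_LOAD_MODEL) {
                        responses.add(RESPONSE_LOAD_MODEL_SUCCESS); // "model loaded"
                    } else if (msg == REQUEST_RUN_MODEL) {
                        responses.add(RESPONSE_RUN_MODEL_SUCCESS);  // "model ran"
                        return;
                    }
                }
            } catch (InterruptedException ignored) { }
        });
        worker.start();
        requests.put(REQUEST_LOAD_MODEL);   // what sw.setChecked(true) triggers
        requests.put(REQUEST_RUN_MODEL);    // what a successful load triggers
        worker.join();
        return responses;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(runFlow()); // [0, 2]
    }
}
```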

Next: Predictor.java

public class Predictor {
    public boolean isLoaded = false;
    public float inferenceTime = 0;
    public String modelName = "";

    protected Context appCtx = null;
    protected int whichDevice = 0; // 0: CPU 1: NPU
    protected ArrayList<PaddlePredictor> paddlePredictors = new ArrayList<PaddlePredictor>(); // 0: CPU 1: NPU

    public Predictor() {
    }

    public boolean init(Context appCtx, String modelPath) {
        this.appCtx = appCtx;
        isLoaded = loadModel(modelPath);
        return isLoaded;
    }

    protected boolean loadModel(String modelPath) {
        // release model if exists
        releaseModel();

        // load model
        if (modelPath.isEmpty())return false;

        String realPath = modelPath;
        if (!modelPath.substring(0, 1).equals("/")) {
            // read model files from custom path if the first character of mode path is '/'
            // otherwise copy model to cache from assets
            realPath = appCtx.getCacheDir() + "/" + modelPath;
            Utils.copyDirectoryFromAssets(appCtx, modelPath, realPath);
        }
        if (realPath.isEmpty())return false;
        modelName = realPath.substring(realPath.lastIndexOf("/") + 1);

        // run on CPU
        CxxConfig config = new CxxConfig();
        config.setModelDir(realPath);
        Place preferredPlace = new Place(Place.TargetType.ARM, Place.PrecisionType.FLOAT);
        Place[] validPlaces = new Place[2];
        validPlaces[0] = new Place(Place.TargetType.HOST, Place.PrecisionType.FLOAT);
        validPlaces[1] = new Place(Place.TargetType.ARM, Place.PrecisionType.FLOAT);
        config.setPreferredPlace(preferredPlace);
        config.setValidPlaces(validPlaces);
        paddlePredictors.add(PaddlePredictor.createPaddlePredictor(config));
        return true;
    }

    public void releaseModel() {
        paddlePredictors.clear();
        isLoaded = false;
        modelName = "";
    }


    public Tensor getInput(int idx) {
        if (paddlePredictors.size() < whichDevice + 1) {
            return null;
        }
        return paddlePredictors.get(whichDevice).getInput(idx);
    }

    public Tensor getOutput(int idx) {
        if (paddlePredictors.size() < whichDevice + 1) {
            return null;
        }
        return paddlePredictors.get(whichDevice).getOutput(idx);
    }

    public boolean runModel() {
        if (paddlePredictors.size() < whichDevice + 1) {
            return false;
        }
        // warm up
        paddlePredictors.get(whichDevice).run();

        // inference
        Date start = new Date();
        paddlePredictors.get(whichDevice).run();
        Date end = new Date();

        inferenceTime = (end.getTime() - start.getTime());
        return true;
    }
}
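One detail worth noting in runModel() above is the warm-up call before the timed run: the first run pays one-off costs (JIT compilation, cache fill), so only the second run is measured. A plain-Java sketch of the same pattern, with no Paddle dependency (dummyRun is a hypothetical stand-in for the predictor's run()):

```java
public class TimingDemo {
    // Stand-in for PaddlePredictor.run(); here it just burns a little CPU.
    static long dummyRun() {
        long acc = 0;
        for (int i = 0; i < 1_000_000; i++) acc += i;
        return acc;
    }

    // One warm-up call, then a timed call -- the same pattern as Predictor.runModel().
    static float timeOneRun() {
        dummyRun();                          // warm up
        long start = System.currentTimeMillis();
        dummyRun();                          // the measured "inference"
        long end = System.currentTimeMillis();
        return (float) (end - start);
    }

    public static void main(String[] args) {
        System.out.println("inference time: " + timeOneRun() + " ms");
    }
}
```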

Then: ImgClassifyPredictor.java

public class ImgClassifyPredictor extends Predictor {
    private static final String TAG = ImgClassifyPredictor.class.getSimpleName();
    protected Vector<String> wordLabels = new Vector<String>();

    protected Bitmap imageData = null;
    protected String top1Result = "";
    protected String top2Result = "";
    protected String top3Result = "";
    protected float preprocessTime = 0;
    protected float postprocessTime = 0;

    // model config
    protected String modelPath = "MobileNetV3";
    protected String labelPath = "synset_words.txt";
    protected long[] inputShape = new long[]{1,3,224,224};
    protected float[] inputMean = new float[]{0.485f,0.456f,0.406f};
    protected float[] inputStd = new float[]{0.229f,0.224f,0.225f};
    public ImgClassifyPredictor(String modelPath) {
        super();
        this.modelPath = modelPath;
    }

    public boolean init(Context appCtx) {
        if (inputShape.length != 4) {
            Log.i(TAG, "size of input shape should be: 4");
            return false;
        }
        if (inputMean.length != inputShape[1]) {
            Log.i(TAG, "size of input mean should be: " + Long.toString(inputShape[1]));
            return false;
        }
        if (inputStd.length != inputShape[1]) {
            Log.i(TAG, "size of input std should be: " + Long.toString(inputShape[1]));
            return false;
        }
        if (inputShape[0] != 1) {
            Log.i(TAG, "only one batch is supported in the image classification demo, you can use any batch size in " +
                    "your Apps!");
            return false;
        }
        if (inputShape[1] != 1 && inputShape[1] != 3) {
            Log.i(TAG, "only one/three channels are supported in the image classification demo, you can use any " +
                    "channel size in your Apps!");
            return false;
        }
        super.init(appCtx, modelPath);
        if (!super.isLoaded) {
            return false;
        }
        isLoaded &= loadLabel(labelPath);
        return isLoaded;
    }

    protected boolean loadLabel(String labelPath) {
        wordLabels.clear();
        // load word labels from file
        try {
            InputStream assetsInputStream = appCtx.getAssets().open(labelPath);
            int available = assetsInputStream.available();
            byte[] lines = new byte[available];
            assetsInputStream.read(lines);
            assetsInputStream.close();
            String words = new String(lines);
            String[] contents = words.split("\n");
            for (String content : contents) {
                int first_space_pos = content.indexOf(" ");
                if (first_space_pos >= 0 && first_space_pos < content.length()) {
                    wordLabels.add(content.substring(first_space_pos));
                }
            }
            Log.i(TAG, "word label size: " + wordLabels.size());
        } catch (Exception e) {
            Log.e(TAG, e.getMessage());
            return false;
        }
        return true;
    }

    public Tensor getInput(int idx) {
        return super.getInput(idx);
    }

    public Tensor getOutput(int idx) {
        return super.getOutput(idx);
    }

    public boolean runModel(Bitmap imageData) {
        setImageData(imageData);
        return runModel();
    }

    public boolean runModel() {
        if (imageData == null) {
            return false;
        }

        // set input shape
        Tensor inputTensor = getInput(0);
        inputTensor.resize(inputShape);

        // pre-process image, and feed input tensor with pre-processed data
        Date start = new Date();
        int channels = (int) inputShape[1];
        int width = (int) inputShape[3];
        int height = (int) inputShape[2];
        float[] inputData = new float[channels * width * height];
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                int color = imageData.getPixel(j, i);
                float r = (float) red(color) / 255.0f;
                float g = (float) green(color) / 255.0f;
                float b = (float) blue(color) / 255.0f;
                if (channels == 3) {
                    r = r - inputMean[0];
                    g = g - inputMean[1];
                    b = b - inputMean[2];
                    r = r / inputStd[0];
                    g = g / inputStd[1];
                    b = b / inputStd[2];
                    int rIdx = i * width + j;
                    int gIdx = rIdx + width * height;
                    int bIdx = gIdx + width * height;
                    inputData[rIdx] = r;
                    inputData[gIdx] = g;
                    inputData[bIdx] = b;
                } else { // channels = 1
                    float gray = (b + g + r) / 3.0f;
                    gray = gray - inputMean[0];
                    gray = gray / inputStd[0];
                    inputData[i * width + j] = gray;
                }
            }
        }
        inputTensor.setData(inputData);
        Date end = new Date();
        preprocessTime = (float) (end.getTime() - start.getTime());

        // inference
        super.runModel();

        // fetch output tensor
        Tensor outputTensor = getOutput(0);

        // post-process
        start = new Date();
        long outputShape[] = outputTensor.shape();
        long outputSize = 1;
        for (long s : outputShape) {
            outputSize *= s;
        }
        int[] max_index = new int[3]; // top-3 label indices, best first
        double[] max_num = new double[3]; // top-3 scores, best first
        for (int i = 0; i < outputSize; i++) {
            float tmp = outputTensor.getFloatData()[i];
            int tmp_index = i;
            // bubble the current score through the top-3 list; each add/subtract
            // triple is an in-place swap of (tmp, max_num[j]) and (tmp_index, max_index[j])
            for (int j = 0; j < 3; j++) {
                if (tmp > max_num[j]) {
                    tmp_index += max_index[j];
                    max_index[j] = tmp_index - max_index[j];
                    tmp_index -= max_index[j];
                    tmp += max_num[j];
                    max_num[j] = tmp - max_num[j];
                    tmp -= max_num[j];
                }
            }
        }
        end = new Date();
        postprocessTime = (float) (end.getTime() - start.getTime());
        if (wordLabels.size() > 0) {
            top1Result = "Top1: " + wordLabels.get(max_index[0]) + " - " + String.format("%.4f", max_num[0]);
            top2Result = "Top2: " + wordLabels.get(max_index[1]) + " - " + String.format("%.4f", max_num[1]);
            top3Result = "Top3: " + wordLabels.get(max_index[2]) + " - " + String.format("%.4f", max_num[2]);
        }
        return true;
    }

    public Bitmap imageData() {
        return imageData;
    }

    public void setImageData(Bitmap imageData) {
        if (imageData == null) {
            return;
        }
        // scale image to the size of input tensor
        Bitmap rgbaData = imageData.copy(Bitmap.Config.ARGB_8888, true);
        Bitmap scaleData = Bitmap.createScaledBitmap(rgbaData, (int) inputShape[3], (int) inputShape[2], true);
        this.imageData = scaleData;
    }
}
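The pixel loop in runModel() above writes the image in CHW order: the whole R plane first, then G, then B, with each value scaled to [0,1] and normalized as (x - mean) / std. A stand-alone sketch of that layout (PreprocessDemo, toChw, and the 1×2 test image are my own illustrative names and data):

```java
public class PreprocessDemo {
    // Convert an H*W RGB image (0..255 values, row-major, one {r,g,b} triple
    // per pixel) into the CHW float layout fed to the input tensor.
    static float[] toChw(int[][] rgb, int height, int width,
                         float[] mean, float[] std) {
        float[] out = new float[3 * height * width];
        for (int i = 0; i < height; i++) {
            for (int j = 0; j < width; j++) {
                int[] px = rgb[i * width + j];       // {r, g, b}
                int rIdx = i * width + j;            // channel 0 plane
                int gIdx = rIdx + width * height;    // channel 1 plane
                int bIdx = gIdx + width * height;    // channel 2 plane
                out[rIdx] = (px[0] / 255.0f - mean[0]) / std[0];
                out[gIdx] = (px[1] / 255.0f - mean[1]) / std[1];
                out[bIdx] = (px[2] / 255.0f - mean[2]) / std[2];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // A 1x2 "image": one pure-red pixel, one pure-blue pixel.
        int[][] img = {{255, 0, 0}, {0, 0, 255}};
        float[] mean = {0.485f, 0.456f, 0.406f}; // same constants as inputMean
        float[] std  = {0.229f, 0.224f, 0.225f}; // same constants as inputStd
        System.out.println(java.util.Arrays.toString(toChw(img, 1, 2, mean, std)));
    }
}
```

The R plane occupies indices 0..H*W-1, G the next H*W slots, and B the last, which is exactly what the rIdx/gIdx/bIdx arithmetic in runModel() produces.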

Finally, the main program: MainActivity.java

public class MainActivity extends AppCompatActivity implements View.OnClickListener {
    private static final String TAG = MainActivity.class.getSimpleName();

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        findViewById(R.id.v_img_classify).setOnClickListener(this);

        // jump straight to the classification screen on launch (a shortcut while developing)
        Intent intent = new Intent(MainActivity.this, ImgClassifyActivity.class);
        startActivity(intent);
    }


    @Override
    public void onClick(View v) {
        switch (v.getId()) {
            case R.id.v_img_classify:
                Intent intent = new Intent(MainActivity.this, ImgClassifyActivity.class);
                startActivity(intent);
                break;
        }
    }

    @Override
    protected void onDestroy() {
        super.onDestroy();
        System.exit(0);
    }
}

And the utility class: Utils.java

public class Utils {
    public static void copyFileFromAssets(Context appCtx, String srcPath, String dstPath) {
        if (srcPath.isEmpty() || dstPath.isEmpty()) {
            return;
        }
        InputStream is = null;
        OutputStream os = null;
        try {
            is = new BufferedInputStream(appCtx.getAssets().open(srcPath));
            os = new BufferedOutputStream(new FileOutputStream(new File(dstPath)));
            byte[] buffer = new byte[1024];
            int length = 0;
            while ((length = is.read(buffer)) != -1) {
                os.write(buffer, 0, length);
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            try {
                os.close();
                is.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    public static void copyDirectoryFromAssets(Context appCtx, String srcDir, String dstDir) {
        if (srcDir.isEmpty() || dstDir.isEmpty()) {
            return;
        }
        try {
            if (!new File(dstDir).exists()) {
                new File(dstDir).mkdirs();
            }
            for (String fileName : appCtx.getAssets().list(srcDir)) {
                String srcSubPath = srcDir + File.separator + fileName;
                String dstSubPath = dstDir + File.separator + fileName;
                if (new File(srcSubPath).isDirectory()) {
                    copyDirectoryFromAssets(appCtx, srcSubPath, dstSubPath);
                } else {
                    copyFileFromAssets(appCtx, srcSubPath, dstSubPath);
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
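The buffered copy loop in copyFileFromAssets works for any stream pair; outside Android (where there is no Context to open assets from) the same loop can be exercised on plain files. CopyDemo and copyFile below are hypothetical names for a plain-file variant, using try-with-resources instead of the manual finally block:

```java
import java.io.*;

public class CopyDemo {
    // Same 1 KB buffered loop as Utils.copyFileFromAssets, between plain files.
    static void copyFile(File src, File dst) throws IOException {
        try (InputStream is = new BufferedInputStream(new FileInputStream(src));
             OutputStream os = new BufferedOutputStream(new FileOutputStream(dst))) {
            byte[] buffer = new byte[1024];
            int length;
            while ((length = is.read(buffer)) != -1) {
                os.write(buffer, 0, length);
            }
        }
    }

    public static void main(String[] args) throws IOException {
        File src = File.createTempFile("copydemo-src", ".bin");
        File dst = File.createTempFile("copydemo-dst", ".bin");
        try (FileOutputStream fos = new FileOutputStream(src)) {
            fos.write("hello paddle".getBytes());
        }
        copyFile(src, dst);
        System.out.println(dst.length() + " bytes copied");
    }
}
```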

None of the code above includes the package declaration or import statements; add those for your own project.

Finally, let me indulge in a few screenshots of the app running; you already saw some in the previous article.

[App run screenshots]