Tensorflow lite for 移动端安卓开发(三)——移动端测试自己的模型

Tensorflow-lite官方给的应用是一个摄像头demo,主要由ImageClassifier类和Camera2BasicFragment类构成,ImageClassifier类为一个抽象类,由浮点类和数字量化类两类继承,主要实现读取,模型和预测的功能。Camera2BasicFragment类为碎片类,主要实现摄像头的预览功能。基于项目需要,为了能够在移动端测试model的性能,在原demo的基础上开发了一个测试demo,从移动端本地读取测试集进行预测,将预测结果以txt保存在本地,同时计算每类的精确率和召回率在终端显示,先给出demo效果图。
这里写图片描述
这里写图片描述
第一个图展示的是float模型跑出来的结果,第二个图展示的是量化模型的结果Quant量化模型跑出来的结果精度下降很多。
demo的github代码如下:https://github.com/GeekLee95/TFlite_android_test/tree/master
代码主要由以下四个类构成
这里写图片描述
ImageClassifer类 为抽象类
ImageClassifierFloatInception为浮点型子类,对应的浮点模型为assets资源下的7_float.tflite
ImageClaaifierQuantizedMobileNet为量化型子类,对应的数字量化模型为assets资源下的7.tflite
Mainactivity为主活动,主要涉及读取文件,图片格式转化和模型预测等方法。
output_labels.txt为模型的标签文件。

下面介绍主活动的主要方法。

1). public static void verifyStoragePermissions(Activity activity)
该函数实现动态申请权限,android 6.0以后为了提高系统安全,必须要在程序中动态申请权限
首先在清单文件中配置需要申请的权限,

<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.example.liuli.openfiles">
    <uses-permission android:name="android.permission.MOUNT_UNMOUNT_FILESYSTEM"/>
    <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
    <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>

然后再动态申请

public static void verifyStoragePermissions(Activity activity){
        try{
            int permission= ActivityCompat.checkSelfPermission(activity,"android.permission.WRITE_EXTERNAL_STORAGE");
            if(permission!= PackageManager.PERMISSION_DENIED){
                ActivityCompat.requestPermissions(activity,PERMISSIONS_STORGE,REQUEST_EXTERNAL_STORAGE);
            }
        } catch (Exception e){
            e.printStackTrace();
        }
   }

2). private List getImagePath()从本地存储中获取测试图片路径,可以选择内部存储(外置SD卡)和扩展存储卡(TF卡)路径。

private List<String> getImagePath(){
        List<String> dirpath = getExtSDCardPathList();
        Log.d("sd_path",dirpath.get(0));
        Log.d("tf_path",dirpath.get(1));
        tfpath = dirpath.get(1);
        List<String> imagePathList = new ArrayList<String>();
        String filepath = tfpath+ File.separator+"DCIM"+File.separator+"TEST";
        //String filepath = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_PICTURES).toString();
        //Context context = getApplicationContext(); //获取当前上下文
        //String filepath = context.getExternalFilesDir("DCIM")+File.separator;
        //得到该路径文件夹下的所有文件
        Log.d("filepath",filepath);
        File fileAll = new File(filepath);
        boolean result = fileAll.exists();
        File[] files = fileAll.listFiles();
        for(int i = 0;i<files.length;i++){
            File file = files[i];
            if(checkIsImageFile(file.getPath())){
                imagePathList.add(file.getPath());
            }
        }
        return imagePathList;
    }

3). private Bitmap createImageThumbnail(String filePath,int newHeight,int newWidth) 将原始图片缩放成指定大小的bitmap格式,比如mobilenet模型的input_size: 224x224

private Bitmap createImageThumbnail(String filePath,int newHeight,int newWidth){

        Bitmap bm = BitmapFactory.decodeFile(filePath);


        float width = bm.getWidth();
        float height = bm.getHeight();
        Log.i("old_size:","宽度是"+width+",高度是"+height);

        Matrix matrix = new Matrix();

        //计算宽高缩放率
        float scaleWidth = ((float) newWidth)/width;
        float scaleHeight = ((float) newHeight)/height;

        //缩放图片动作
        matrix.postScale(scaleWidth,scaleHeight);
        Bitmap bitmap = Bitmap.createBitmap(bm,0,0,(int)width,(int)height,matrix,true);
        Log.i("new_size:","宽度是"+bitmap.getHeight()+",高度是"+bitmap.getWidth());
        return bitmap;

    }

4). private void classifyFrame(List Frames) 进行模型预测

    private void classifyFrame(List<String> Frames){
        int num = 0;
        int carlessnum = 0,carlessTP = 0,carlessFP = 0;
        int carnormalnum = 0,carnormalTP = 0,carnormalFP = 0;
        int carmorenum = 0,carmoreTP = 0,carmoreFP = 0;
        //显示待预测图片总数
        mShownum.setText(Integer.toString(Frames.size()));
        Log.d("mShownum",Integer.toString(Frames.size()));
        String resultfilepath = tfpath+ File.separator+"DCIM"+File.separator+"TESTRESULT"+File.separator;
        for(int i = 0;i<Frames.size();i++){
            String imagepath = Frames.get(i);
            Bitmap bitmap = createImageThumbnail(imagepath,classifier.getImageSizeX(),classifier.getImageSizeY());
            String result = classifier.classifyFrame(bitmap);
            Log.d("Predict_result"+Integer.toString(i),result);
            String imagename = imagepath.split("/")[imagepath.split("/").length-1];

            //将数据保存到本地
            String resultname = imagename.replace(".jpg",".txt");
            Log.d("resultname",resultname);

            writeTxtToFile(result,resultfilepath,resultname);
            String label = imagename.split("_")[0];
            Log.d("label"+Integer.toString(i),label);

            switch (label){
                case "0":
                    carlessnum++;
                    Log.d("carlessnum",Integer.toString(carlessnum));
                    if(result == classifier.labelList.get(Integer.parseInt(label))){
                        carlessTP++;
                        Log.d("carlessTP",Integer.toString(carlessTP));
                    }
                    break;

                case "1":
                    carnormalnum++;
                    Log.d("carnormalnum",Integer.toString(carnormalnum));
                    if(result == classifier.labelList.get(Integer.parseInt(label))){
                        carnormalTP++;
                        Log.d("carnormalTP",Integer.toString(carnormalTP));
                    }
                    break;
                case "2":
                    carmorenum++;
                    Log.d("carmorenum",Integer.toString(carmorenum));
                    if(result == classifier.labelList.get(Integer.parseInt(label))){
                        carmoreTP++;
                        Log.d("carmoreTP",Integer.toString(carmoreTP));
                    }
                    break;
            }

            if(result != classifier.labelList.get(Integer.parseInt(label))){
                switch (result){
                    case "类别1":
                        carlessFP++;
                        break;

                    case "类别2":
                        carnormalFP++;
                        break;
                    case "类别3":
                        carmoreFP++;
                        break;
                }
            }

            if(result == classifier.labelList.get(Integer.parseInt(label))){
                num++;
            } else{
                wrongFrames.add(imagepath+"predict:"+result);
            }
            Log.d("图片数:", Integer.toString(i+1));
            Log.d("正确数:", Integer.toString(num));
        }
        float result  = (float)num/(float)Frames.size();
        mShowResult.setText(Float.toString(result));
        // 计算每一类的精确率和召回率
        float carlessrec = (float)Math.round((float)carlessTP/(float)carlessnum*10000)/10000;
        float carlessacc = (float) Math.round((float)carlessTP/(float)(carlessTP+carlessFP)*10000)/10000;
        float carnormalrec = (float) Math.round((float)carnormalTP/(float)carnormalnum*10000)/10000;
        float carnormalacc = (float) Math.round((float)carnormalTP /(float)(carnormalTP+carnormalFP)*10000)/10000;
        float carmorerec = (float) Math.round((float) carmoreTP/(float)carmorenum*10000)/10000;
        float carmoreacc = (float) Math.round((float)carmoreTP/(float)(carmoreTP+carmoreFP)*10000)/10000;

        mShowcarlessacc.setText(Float.toString(carlessacc));
        mShowcarlessrec.setText(Float.toString(carlessrec));
        mShowcarlessnum.setText(Integer.toString(carlessnum));


        mShowcarnormacc.setText(Float.toString(carnormalacc));
        mShowcarnormrec.setText(Float.toString(carnormalrec));
        mShowcarnormnum.setText(Integer.toString(carnormalnum));

        mShowcarmoreacc.setText(Float.toString(carmoreacc));
        mShowcarmorerec.setText(Float.toString(carmorerec));
        mShowcarmorenum.setText(Integer.toString(carmorenum));
    }

后续将会对模型进行改进和完善。

  • 2
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值