基于TNN在Android手机上实现图像分类

前言

TNN:由腾讯优图实验室打造,移动端高性能、轻量级推理框架,同时拥有跨平台、高性能、模型压缩、代码裁剪等众多突出优势。TNN框架在原有Rapidnet、ncnn框架的基础上进一步加强了移动端设备的支持以及性能优化,同时也借鉴了业界主流开源框架高性能和良好拓展性的优点。

教程源码地址:https://github.com/yeyupiaoling/ClassificationForAndroid/tree/master/TNNClassification

编译Android库

  1. 安装cmake 3.12
# 卸载旧的cmake
sudo apt-get autoremove cmake

# 下载cmake3.12
wget https://cmake.org/files/v3.12/cmake-3.12.2-Linux-x86_64.tar.gz
tar zxvf cmake-3.12.2-Linux-x86_64.tar.gz

# 移动目录并添加软连接
sudo mv cmake-3.12.2-Linux-x86_64 /opt/cmake-3.12.2
sudo ln -sf /opt/cmake-3.12.2/bin/*  /usr/bin/
  1. 添加Android NDK
wget https://dl.google.com/android/repository/android-ndk-r21b-linux-x86_64.zip
unzip android-ndk-r21b-linux-x86_64.zip
# 添加环境变量,留意你实际下载地址
export ANDROID_NDK=/mnt/d/android-ndk-r21b
  1. 安装编译环境
sudo apt-get install attr
  1. 开始编译
git clone https://github.com/Tencent/TNN.git
cd TNN/scripts

vim build_android.sh
 ABIA32="armeabi-v7a"
 ABIA64="arm64-v8a"
 STL="c++_static"
 SHARED_LIB="ON"                # ON表示编译动态库,OFF表示编译静态库
 ARM="ON"                       # ON表示编译带有Arm CPU版本的库
 OPENMP="ON"                    # ON表示打开OpenMP
 OPENCL="ON"                    # ON表示编译带有Arm GPU版本的库
 SHARING_MEM_WITH_OPENGL=0      # 1表示OpenGL的Texture可以与OpenCL共享

执行编译

./build_android.sh

编译完成后,会在当前目录的release目录下生成对应的armeabi-v7a库,arm64-v8a库和include头文件,这些文件在下一步的Android开发都需要使用到。

模型转换

接下来我们需要把Tensorflow,onnx等其他的模型转换为TNN的模型。目前 TNN 支持业界主流的模型文件格式,包括ONNX、PyTorch、TensorFlow 以及 Caffe 等。TNN 将 ONNX 作为中间层,借助于ONNX 开源社区的力量,来支持多种模型文件格式。如果要将PyTorch、TensorFlow 以及 Caffe 等模型文件格式转换为 TNN,首先需要使用对应的模型转换工具,统一将各种模型格式转换成为 ONNX 模型格式,然后将 ONNX 模型转换成 TNN 模型。

sudo docker pull turandotkay/tnn-convert
sudo docker tag turandotkay/tnn-convert:latest tnn-convert:latest
sudo docker rmi turandotkay/tnn-convert:latest

针对不同的模型转换,有不同的命令,如onnx2tnn,caffe2tnn,tf2tnn。

docker run --volume=$(pwd):/workspace -it tnn-convert:latest  python3 ./converter.py tf2tnn \
    -tp /workspace/mobilenet_v1.pb \
    -in "input[1,224,224,3]" \
    -on MobilenetV1/Predictions/Reshape_1 \
    -v v1.0 \
    -optimize

通过上面的输出,可以发现针对 TF 模型的转换,convert2tnn 工具提供了很多参数,我们一次对下面的参数进行解释:

  • tp 参数(必须)
    通过 “-tp” 参数指定需要转换的模型的路径。目前只支持单个 TF模型的转换,不支持多个 TF 模型的一起转换。
  • in 参数(必须)
    通过 “-in” 参数指定模型输入的名称,输入的名称需要放到“”中,例如,-in “name”。如果模型有多个输入,请使用 “;”进行分割。有的 TensorFlow 模型没有指定 batch 导致无法成功转换为 ONNX 模型,进而无法成功转换为 TNN 模型。你可以通过在名称后添加输入 shape 进行指定。shape 信息需要放在 [] 中。例如:-in “name[1,28,28,3]”。
  • on 参数(必须)
    通过 “-on” 参数指定模型输入的名称,如果模型有多个输出,请使用 “;”进行分割
  • output_dir 参数:
    可以通过 “-o ” 参数指定输出路径,但是在 docker 中我们一般不使用这个参数,默认会将生成的 TNN 模型放在当前和 TF 模型相同的路径下。
  • optimize 参数(可选)
    可以通过 “-optimize” 参数来对模型进行优化,我们强烈建议你开启这个选项,只有在开启这个选项模型转换失败时,我们才建议你去掉 “-optimize” 参数进行重新尝试
  • v 参数(可选)
    可以通过 -v 来指定模型的版本号,以便于后期对模型进行追踪和区分。
  • half 参数(可选)
    可以通过 -half 参数指定,模型数据通过 FP16 进行存储,减少模型的大小,默认是通过 FP32 的方式进行存储模型数据的。
  • align 参数(可选)
    可以通过 -align 参数指定,将 转换得到的 TNN 模型和原模型进行对齐,确定 TNN 模型是否转换成功。当前仅支持单输入单输出模型和单输入多输出模型。 align 只支持 FP32 模型的校验,所以使用 align 的时候不能使用 half
  • input_file 参数(可选)
    可以通过 -input_file 参数指定模型对齐所需要的输入文件的名称,输入需要遵循如下格式
  • ref_file 参数(可选)
    可以通过 -ref_file 参数指定待对齐的输出文件的名称,输出需遵循如下格式。生成输出的代码可以参考

成功转换会输出以下的日志。

----------  convert model, please wait a moment ----------

Converter Tensorflow to TNN model

Convert TensorFlow to ONNX model succeed!

Converter ONNX to TNN Model

Converter ONNX to TNN model succeed!

最终会得到这两个模型文件,mobilenet_v1.opt.tnnmodel mobilenet_v1.opt.tnnproto

开发Android项目

  1. 将转换的模型放在assets目录下。
  2. 把上一步编译得到的include目录复制到Android项目的app目录下。
  3. 把上一步编译得到的armeabi-v7aarm64-v8a目录复制到main/jniLibs下。
  4. app/src/main/cpp/目录下编写JNI的C++代码。

TNN工具

编写一个ImageClassifyUtil.java工具类,关于TNN的操作都在这里完成,如加载模型、预测。

下面三个就是TNN的JNI接口,通过这个接口完成模型加载,预测,当不使用的时候和可以调用deinit()清空对象。

public native int init(String modelPath, String protoPath, int computeUnitType);

public native float[] predict(Bitmap image, int width, int height);

public native int deinit();

通过上面的JNI接口,下面就可以实现图像识别了,WIDTHHEIGHT是模型输入图片的大小。为了兼容图片路径和Bitmap格式的图片预测,这里创建了两个重载方法。

private static final int WIDTH = 224;
private  static final int HEIGHT = 224;

public ImageClassifyUtil() {
    System.loadLibrary("TNN");
    System.loadLibrary("tnn_wrapper");
}

// 重载方法,根据图片路径转Bitmap预测
public float[] predictImage(String image_path) throws Exception {
    if (!new File(image_path).exists()) {
        throw new Exception("image file is not exists!");
    }
    FileInputStream fis = new FileInputStream(image_path);
    Bitmap bitmap = BitmapFactory.decodeStream(fis);
    Bitmap scaleBitmap = Bitmap.createScaledBitmap(bitmap, WIDTH, HEIGHT, false);
    float[] result = predictImage(scaleBitmap);
    if (bitmap.isRecycled()) {
        bitmap.recycle();
    }
    return result;
}

// 重载方法,直接使用Bitmap预测
public float[] predictImage(Bitmap bitmap) {
    Bitmap scaleBitmap = Bitmap.createScaledBitmap(bitmap, WIDTH, HEIGHT, false);
    float[] results = predict(scaleBitmap, WIDTH, HEIGHT);
    int l = getMaxResult(results);
    return new float[]{l, results[l] * 0.01f};
}

这里创建一个获取最大概率值,并把下标返回的方法,其实就是获取概率最大的预测标签。

public static int getMaxResult(float[] result) {
    float probability = 0;
    int r = 0;
    for (int i = 0; i < result.length; i++) {
        if (probability < result[i]) {
            probability = result[i];
            r = i;
        }
    }
    return r;
}

不同的模型,训练的预处理方式可能不一样,TNN 的图像预处理在C++中完成,代码片段

TNN_NS::MatConvertParam input_cvt_param;
input_cvt_param.scale = {1.0 / (255 * 0.229), 1.0 / (255 * 0.224), 1.0 / (255 * 0.225), 0.0};
input_cvt_param.bias  = {-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225, 0.0};
auto status = instance_->SetInputMat(input_mat, input_cvt_param);

选择图片预测

本教程会有两个页面,一个是选择图片进行预测的页面,另一个是使用相机实时预测并显示预测结果。以下为activity_main.xml的代码,通过按钮选择图片,并在该页面显示图片和预测结果。

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    tools:context=".MainActivity">

    <ImageView
        android:id="@+id/image_view"
        android:layout_width="match_parent"
        android:layout_height="400dp" />

    <TextView
        android:id="@+id/result_text"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_below="@id/image_view"
        android:text="识别结果"
        android:textSize="16sp" />


    <LinearLayout
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_alignParentBottom="true"
        android:orientation="horizontal">

        <Button
            android:id="@+id/select_img_btn"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="选择照片" />


        <Button
            android:id="@+id/open_camera"
            android:layout_width="0dp"
            android:layout_height="wrap_content"
            android:layout_weight="1"
            android:text="实时预测" />

    </LinearLayout>

</RelativeLayout>

MainActivity.java中,进入到页面我们就要先加载模型,我们是把模型放在Android项目的assets目录的,我们需要把模型复制到一个缓存目录,然后再从缓存目录加载模型,同时还有读取标签名,标签名称按照训练的label顺序存放在assets的label_list.txt,以下为实现代码。

classNames = Utils.ReadListFromFile(getAssets(), "label_list.txt");
String protoContent = getCacheDir().getAbsolutePath() + File.separator + "squeezenet_v1.1.tnnproto";
Utils.copyFileFromAsset(MainActivity.this, "squeezenet_v1.1.tnnproto", protoContent);
String modelContent = getCacheDir().getAbsolutePath() + File.separator + "squeezenet_v1.1.tnnmodel";
Utils.copyFileFromAsset(MainActivity.this, "squeezenet_v1.1.tnnmodel", modelContent);

imageClassifyUtil = new ImageClassifyUtil();
int status = imageClassifyUtil.init(modelContent, protoContent, USE_GPU ? 1 : 0);
if (status == 0){
    Toast.makeText(MainActivity.this, "模型加载成功!", Toast.LENGTH_SHORT).show();
}else {
    Toast.makeText(MainActivity.this, "模型加载失败!", Toast.LENGTH_SHORT).show();
    finish();
}

添加两个按钮点击事件,可以选择打开相册读取图片进行预测,或者打开另一个Activity进行调用摄像头实时识别。

Button selectImgBtn = findViewById(R.id.select_img_btn);
Button openCamera = findViewById(R.id.open_camera);
imageView = findViewById(R.id.image_view);
textView = findViewById(R.id.result_text);
selectImgBtn.setOnClickListener(new View.OnClickListener() {
    @Override
    public void onClick(View v) {
        // 打开相册
        Intent intent = new Intent(Intent.ACTION_PICK);
        intent.setType("image/*");
        startActivityForResult(intent, 1);
    }
});
openCamera.setOnClickListener(new View.OnClickListener() {
    @Override
    public void onClick(View v) {
        // 打开实时拍摄识别页面
        Intent intent = new Intent(MainActivity.this, CameraActivity.class);
        startActivity(intent);
    }
});

当打开相册选择照片之后,回到原来的页面,在下面这个回调方法中获取选择图片的Uri,通过Uri可以获取到图片的绝对路径。如果Android8以上的设备获取不到图片,需要在AndroidManifest.xml配置文件中的application添加android:requestLegacyExternalStorage="true"。拿到图片路径之后,调用TFLiteClassificationUtil类中的predictImage()方法预测并获取预测值,在页面上显示预测的标签、对应标签的名称、概率值和预测时间。

@Override
protected void onActivityResult(int requestCode, int resultCode, @Nullable Intent data) {
    super.onActivityResult(requestCode, resultCode, data);
    String image_path;
    if (resultCode == Activity.RESULT_OK) {
        if (requestCode == 1) {
            if (data == null) {
                Log.w("onActivityResult", "user photo data is null");
                return;
            }
            Uri image_uri = data.getData();
            image_path = getPathFromURI(MainActivity.this, image_uri);
            try {
                // 预测图像
                FileInputStream fis = new FileInputStream(image_path);
                imageView.setImageBitmap(BitmapFactory.decodeStream(fis));
                long start = System.currentTimeMillis();
                float[] result = imageClassifyUtil.predictImage(image_path);
                long end = System.currentTimeMillis();
                String show_text = "预测结果标签:" + (int) result[0] +
                        "\n名称:" +  classNames[(int) result[0]] +
                        "\n概率:" + result[1] +
                        "\n时间:" + (end - start) + "ms";
                textView.setText(show_text);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
}

上面获取的Uri可以通过下面这个方法把Url转换成绝对路径。

// get photo from Uri
public static String getPathFromURI(Context context, Uri uri) {
    String result;
    Cursor cursor = context.getContentResolver().query(uri, null, null, null, null);
    if (cursor == null) {
        result = uri.getPath();
    } else {
        cursor.moveToFirst();
        int idx = cursor.getColumnIndex(MediaStore.Images.ImageColumns.DATA);
        result = cursor.getString(idx);
        cursor.close();
    }
    return result;
}

摄像头实时预测

在调用相机实时预测我就不再介绍了,原理都差不多,具体可以查看https://github.com/yeyupiaoling/ClassificationForAndroid/tree/master/TFLiteClassification中的源代码。核心代码如下,创建一个子线程,子线程中不断从摄像头预览的AutoFitTextureView上获取图像,并执行预测,并在页面上显示预测的标签、对应标签的名称、概率值和预测时间。每一次预测完成之后都立即获取图片继续预测,只要预测速度够快,就可以看成实时预测。

private Runnable periodicClassify =
        new Runnable() {
            @Override
            public void run() {
                synchronized (lock) {
                    if (runClassifier) {
                        // 开始预测前要判断相机是否已经准备好
                        if (getApplicationContext() != null && mCameraDevice != null && mnnClassification != null) {
                            predict();
                        }
                    }
                }
                if (mInferThread != null && mInferHandler != null && mCaptureHandler != null && mCaptureThread != null) {
                    mInferHandler.post(periodicClassify);
                }
            }
        };

// 预测相机捕获的图像
private void predict() {
    // 获取相机捕获的图像
    Bitmap bitmap = mTextureView.getBitmap();
    try {
        // 预测图像
        long start = System.currentTimeMillis();
        float[] result = imageClassifyUtil.predictImage(bitmap);
        long end = System.currentTimeMillis();
        String show_text = "预测结果标签:" + (int) result[0] +
                "\n名称:" +  classNames[(int) result[0]] +
                "\n概率:" + result[1] +
                "\n时间:" + (end - start) + "ms";
        textView.setText(show_text);
    } catch (Exception e) {
        e.printStackTrace();
    }
}

本项目中使用的了读取图片的权限和打开相机的权限,所以不要忘记在AndroidManifest.xml添加以下权限申请。

<uses-permission android:name="android.permission.CAMERA"/>
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>

如果是Android 6 以上的设备还要动态申请权限。

    // check had permission
    private boolean hasPermission() {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
            return checkSelfPermission(Manifest.permission.CAMERA) == PackageManager.PERMISSION_GRANTED &&
                    checkSelfPermission(Manifest.permission.READ_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED &&
                    checkSelfPermission(Manifest.permission.WRITE_EXTERNAL_STORAGE) == PackageManager.PERMISSION_GRANTED;
        } else {
            return true;
        }
    }

    // request permission
    private void requestPermission() {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
            requestPermissions(new String[]{Manifest.permission.CAMERA,
                    Manifest.permission.READ_EXTERNAL_STORAGE,
                    Manifest.permission.WRITE_EXTERNAL_STORAGE}, 1);
        }
    }

效果图:
在这里插入图片描述

  • 5
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 19
    评论
以下是一个基于TNN的C++图像分类代码示例: ```cpp #include <iostream> #include <memory> #include "tnn/core/macro.h" #include "tnn/core/tnn.h" #include "tnn/utils/mat_utils.h" #include "tnn/utils/dims_vector_utils.h" #include "tnn/utils/blob_converter.h" #include "tnn/interpreter/default_model_interpreter.h" #include "tnn/interpreter/tnn/tnn_interpreter.h" #include "tnn/interpreter/tnn/tnn_device.h" #include "tnn/interpreter/tnn/tnn_utils_internal.h" using namespace TNN_NS; int main(int argc, char** argv) { if (argc < 3) { std::cout << "Usage: " << argv[0] << " proto model image" << std::endl; return -1; } std::string proto_content, model_content; if (ReadProtoFile(argv[1], &proto_content) != TNN_OK) { std::cout << "Read proto file failed." << std::endl; return -1; } if (ReadModelFile(argv[2], model_content) != TNN_OK) { std::cout << "Read model file failed." << std::endl; return -1; } std::shared_ptr<TNN> tnn = std::make_shared<TNN>(); TNNStatus status = tnn->Init(proto_content, model_content, "", TNNComputeUnitsCPU); if (status != TNN_OK) { std::cout << "Init TNN failed, error code: " << (int)status << std::endl; return -1; } auto input_dims = tnn->GetInputShape(0); if (input_dims.size() != 4) { std::cout << "Invalid input dims." << std::endl; return -1; } auto input_mat = std::make_shared<Mat>(input_dims, MatType::NCHW_FLOAT); int input_size = DimsVectorUtils::Count(input_dims); auto converter = std::make_shared<NCBlobConverter>(); size_t input_bytes_size = input_size * sizeof(float); RawBuffer input_buffer(input_bytes_size); std::ifstream in_file(argv[3], std::ios::binary); if (!in_file.is_open()) { std::cout << "Open image file failed." << std::endl; return -1; } in_file.read(reinterpret_cast<char*>(input_buffer.force_to<void*>()), input_bytes_size); in_file.close(); converter->ConvertFromHostToDevice(input_buffer, input_mat, nullptr); std::shared_ptr<TNNInterpreter> interpreter = std::make_shared<TNNInterpreter>(); status = interpreter->Init(tnn->GetModelConfig()); if (status != TNN_OK) { std::cout << "Init interpreter failed, error code: " << (int)status << std::endl; return -1; } std::shared_ptr<TNNSession> session = interpreter->CreateSession(tnn->GetModelConfig()); if (!session) { std::cout << "Create session failed." << std::endl; return -1; } status = session->SetInputMat(input_mat); if (status != TNN_OK) { std::cout << "Set input mat failed, error code: " << (int)status << std::endl; return -1; } status = session->Forward(); if (status != TNN_OK) { std::cout << "Forward failed, error code: " << (int)status << std::endl; return -1; } auto output_dims = tnn->GetOutputShape(0); if (output_dims.size() != 2) { std::cout << "Invalid output dims." << std::endl; return -1; } auto output_size = DimsVectorUtils::Count(output_dims); auto output_mat = std::make_shared<Mat>(output_dims, MatType::NCHW_FLOAT); auto output_buffer = std::make_shared<RawBuffer>(output_size * sizeof(float)); converter->ConvertFromDeviceToHost(output_mat, output_buffer, nullptr); status = session->GetOutputMat(output_mat); if (status != TNN_OK) { std::cout << "Get output mat failed, error code: " << (int)status << std::endl; return -1; } std::vector<float> output_data(output_size); memcpy(output_data.data(), output_buffer->force_to<void*>(), output_size * sizeof(float)); int max_index = std::max_element(output_data.begin(), output_data.end()) - output_data.begin(); std::cout << "Predicted class index: " << max_index << std::endl; return 0; } ``` 这段代码会读取一个模型文件和一个图像文件,并使用TNN进行图像分类。注意,这段代码仅用于参考,实际使用时可能需要根据具体情况进行修改。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

夜雨飘零1

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值