ubuntu20.10 tensorflow2.5 将训练后的模型移植到android 平台之官网demo 运行(二)

9 篇文章 0 订阅
8 篇文章 1 订阅

前  述

系统的学习tensorflow ,可以从官网:关于TensorFlow | TensorFlow中文官网 (google.cn) 得到详细的讲解以及demo。

1. 对于图像以及目标的识别主要看:TensorFlow Lite 示例 | TensorFlow中文官网 (google.cn)

目前我使用的是对象检测,在android 设备上运行,所以选择第二个:在android 设备上试试

 图像分类:examples/lite/examples/image_classification/android at master · tensorflow/examples (github.com)

对象检测:examples/lite/examples/object_detection/android at master · tensorflow/examples (github.com)

 点击链接,将出现android demo

点击example 跳转页面,再点击Download ZIP 即可下载

2. 重点:图像分类与对象检测的区别

图像分类:图像分类  |  TensorFlow Lite (google.cn) 都有详细说明的。

对象检测:物体检测  |  TensorFlow Lite (google.cn) 都有详细说明的。

图像分类与物体检测的.tflite文件在运行的时候有很大区别:

图像检测 run 函数:tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());

inputImageBufferimage 资料 

outputProbabilityBuffer标签以及概率

因此图像检测属于单输入,单输出函数。

输出中的每个数字都对应训练数据中的一个标签。将我们的输出和这三个训练标签关联,我们能够看出,这个模型预测了这张图片中的对象有很大概率是一条狗。

标签概率
兔子0.31
仓鼠0.35
0.34
  public List<Recognition> recognizeImage(final Bitmap bitmap, int sensorOrientation) {
    // Logs this method so that it can be analyzed with systrace.
    Trace.beginSection("recognizeImage");

    Trace.beginSection("loadImage");
    long startTimeForLoadImage = SystemClock.uptimeMillis();
    inputImageBuffer = loadImage(bitmap, sensorOrientation);
    long endTimeForLoadImage = SystemClock.uptimeMillis();
    Trace.endSection();
    Log.v(TAG, "Timecost to load the image: " + (endTimeForLoadImage - startTimeForLoadImage));

    // Runs the inference call.
    Trace.beginSection("runInference");
    long startTimeForReference = SystemClock.uptimeMillis();
    tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());
    long endTimeForReference = SystemClock.uptimeMillis();
    Trace.endSection();
    Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference));

    // Gets the map of label and probability.
    Map<String, Float> labeledProbability =
        new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer))
            .getMapWithFloatValue();
    Trace.endSection();

    // Gets top-k results.
    return getTopKProbability(labeledProbability);
  }

对象检测 run 函数:tfLite.runForMultipleInputsOutputs(inputArray, outputMap);

因此图像检测属于单输入,多输出函数。

inputArrayimage 资料 

outputMap会输出10个物体的坐标位置,标签索引,概率分数以及检测物体数量(一般固定是10)

数组0:outputLocations [1][10][4]   十组数据,每组四个[top, left, bottom, right]这是归一化坐标所以需要乘以图片大小(一般检测图片会按照tflite 训练的大小进行缩小,才能放入 inputArray 中进行run),才是最终的圈起物体的坐标。

数组1:outputClasses[1][10]   十组数据,对应着检测物体的标签索引。

数组2:outputScores [1][10]   十组数据,对应着检测物体的符合概率。

数组3:numDetections [1]   一个数据,对应着检测物体的总和。

注意:

第一个物体的坐标位置:outputLocations [0][0][0]  ~ outputLocations [0][0][3]  ;第一个物体的标签索引:outputClasses[0][0] ;第一个物体的符合概率:outputScores [0][0] 

第二个物体的坐标位置:outputLocations [0][1][0]  ~ outputLocations [0][1][3]  ;第二个物体的标签索引:outputClasses[0][1] ;第二个物体的符合概率:outputScores [0][1] 

............

第十个物体的坐标位置:outputLocations [0][9][0]  ~ outputLocations [0][9][3]  ;第十个物体的标签索引:outputClasses[0][9] ;第十个物体的符合概率:outputScores [0][9] 

以此类推到第十个物体,数组一一对应。

该模型输出四个数组,分别对应索引的 0-3。前三个数组描述10个被检测到的物体,每个数组的最后一个元素匹配每个对象。检测到的物体数量总是10。

索引名称描述
0坐标[10][4] 多维数组,每一个元素由 0 到1 之间的浮点数,内部数组表示了矩形边框的 [top, left, bottom, right]
1类型10个整型元素组成的数组(输出为浮点型值),每一个元素代表标签文件中的索引。
2分数10个整型元素组成的数组,元素值为 0 至 1 之间的浮点数,代表检测到的类型
3检测到的物体和数量长度为1的数组,元素为检测到的总数
类别分数坐标
苹果0.92[18, 21, 57, 63]
香蕉0.88[100, 30, 180, 150]
草莓0.87[7, 82, 89, 163]
香蕉0.23[42, 66, 57, 83]
苹果0.11[6, 42, 31, 58]
    int NUM_DETECTIONS = 10;
    outputLocations = new float[1][NUM_DETECTIONS][4];
    outputClasses = new float[1][NUM_DETECTIONS];
    outputScores = new float[1][NUM_DETECTIONS];
    numDetections = new float[1];

    Object[] inputArray = {imgData};
    Map<Integer, Object> outputMap = new HashMap<>();
    outputMap.put(0, outputLocations);
    outputMap.put(1, outputClasses);
    outputMap.put(2, outputScores);
    outputMap.put(3, numDetections);
    Trace.endSection();

    // Run the inference call.
    Trace.beginSection("run");
    tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
    Trace.endSection();

3. 下载完成对象检测的android demo,在android studio 中运行。

已运行正常的demo 链接:https://download.csdn.net/download/Chhjnavy/21569431

从github 上下载demo 运行会遇到以下几个问题:

(1) Connection timed out: connect. If you are behind an HTTP proxy, please configure the proxy settings either in IDE or Gradle.

原因分析:TF2_object_detection\app\download_model.gradle 文件中需要下载detect.tflite ,由于外网限制,所以无法下载。

task downloadModelFile(type: Download) {
    src 'https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2?lite-format=tflite'
    dest project.ext.ASSET_DIR + '/detect.tflite'
    overwrite false
}

preBuild.dependsOn downloadModelFile

解决问题:通过Home | TensorFlow Hub (google.cn) 加速下载,将

src 'https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2?lite-format=tflite'
#更新为
src 'https://hub.tensorflow.google.cn/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1'

 再次运行就会发现 src\main\assets\ 已经有 detect.tflite 文件

2)出现第二个问题:The identifier of the model is invalid. The buffer may not be a valid TFLite model flatbuffer.

原因分析:下载下来的detect.tflite 是个无效的文件

解决问题:通过翻墙外网还按照https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2?lite-format=tflite 下载,下载的文件名字是:lite-model_ssd_mobilenet_v1_1_metadata_2.tflite,将其放入 src\main\assets\ 中。

但是download_model.gradle 依然将src 改成: https://hub.tensorflow.google.cn/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1  ,而且保留 src\main\assets\ 中的 detect.tflite 文件,这样就不会出现Connection timed out 的问题。

3)将demo 中 DetectorActivity.java 以及 DetectorTest.java两处加载 detect.tflite 文件的地方修改成 lite-model_ssd_mobilenet_v1_1_metadata_2.tflite :

4. demo 路径build\outputs\apk\interpreter\debug\ 中有 TF2.apk,直接安装再android 手机上即可运行。

请关注下一篇博客:ubuntu20.10 tensorflow2.5 将训练后的模型移植到android 平台之自己训练模型运行(三)

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值