前 述
系统的学习tensorflow ,可以从官网:关于TensorFlow | TensorFlow中文官网 (google.cn) 得到详细的讲解以及demo。
1. 对于图像以及目标的识别主要看:TensorFlow Lite 示例 | TensorFlow中文官网 (google.cn)
目前我使用的是对象检测,在android 设备上运行,所以选择第二个:在android 设备上试试。
图像分类:examples/lite/examples/image_classification/android at master · tensorflow/examples (github.com)
对象检测:examples/lite/examples/object_detection/android at master · tensorflow/examples (github.com)
点击链接,将出现android demo
点击example 跳转页面,再点击Download ZIP 即可下载
2. 重点:图像分类与对象检测的区别
图像分类:图像分类 | TensorFlow Lite (google.cn) 都有详细说明的。
对象检测:物体检测 | TensorFlow Lite (google.cn) 都有详细说明的。
图像分类与物体检测的.tflite文件在运行的时候有很大区别:
图像检测 run 函数:tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());
inputImageBuffer: image 资料
outputProbabilityBuffer:标签以及概率
因此图像检测属于单输入,单输出函数。
输出中的每个数字都对应训练数据中的一个标签。将我们的输出和这三个训练标签关联,我们能够看出,这个模型预测了这张图片中的对象有很大概率是一条狗。
标签 | 概率 |
---|---|
兔子 | 0.31 |
仓鼠 | 0.35 |
狗 | 0.34 |
public List<Recognition> recognizeImage(final Bitmap bitmap, int sensorOrientation) {
// Logs this method so that it can be analyzed with systrace.
Trace.beginSection("recognizeImage");
Trace.beginSection("loadImage");
long startTimeForLoadImage = SystemClock.uptimeMillis();
inputImageBuffer = loadImage(bitmap, sensorOrientation);
long endTimeForLoadImage = SystemClock.uptimeMillis();
Trace.endSection();
Log.v(TAG, "Timecost to load the image: " + (endTimeForLoadImage - startTimeForLoadImage));
// Runs the inference call.
Trace.beginSection("runInference");
long startTimeForReference = SystemClock.uptimeMillis();
tflite.run(inputImageBuffer.getBuffer(), outputProbabilityBuffer.getBuffer().rewind());
long endTimeForReference = SystemClock.uptimeMillis();
Trace.endSection();
Log.v(TAG, "Timecost to run model inference: " + (endTimeForReference - startTimeForReference));
// Gets the map of label and probability.
Map<String, Float> labeledProbability =
new TensorLabel(labels, probabilityProcessor.process(outputProbabilityBuffer))
.getMapWithFloatValue();
Trace.endSection();
// Gets top-k results.
return getTopKProbability(labeledProbability);
}
对象检测 run 函数:tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
因此图像检测属于单输入,多输出函数。
inputArray: image 资料
outputMap:会输出10个物体的坐标位置,标签索引,概率分数以及检测物体数量(一般固定是10)
数组0:outputLocations [1][10][4] 十组数据,每组四个[top, left, bottom, right]这是归一化坐标所以需要乘以图片大小(一般检测图片会按照tflite 训练的大小进行缩小,才能放入 inputArray 中进行run),才是最终的圈起物体的坐标。
数组1:outputClasses[1][10] 十组数据,对应着检测物体的标签索引。
数组2:outputScores [1][10] 十组数据,对应着检测物体的符合概率。
数组3:numDetections [1] 一个数据,对应着检测物体的总和。
注意:
第一个物体的坐标位置:outputLocations [0][0][0] ~ outputLocations [0][0][3] ;第一个物体的标签索引:outputClasses[0][0] ;第一个物体的符合概率:outputScores [0][0]
第二个物体的坐标位置:outputLocations [0][1][0] ~ outputLocations [0][1][3] ;第二个物体的标签索引:outputClasses[0][1] ;第二个物体的符合概率:outputScores [0][1]
............
第十个物体的坐标位置:outputLocations [0][9][0] ~ outputLocations [0][9][3] ;第十个物体的标签索引:outputClasses[0][9] ;第十个物体的符合概率:outputScores [0][9]
以此类推到第十个物体,数组一一对应。
该模型输出四个数组,分别对应索引的 0-3。前三个数组描述10个被检测到的物体,每个数组的最后一个元素匹配每个对象。检测到的物体数量总是10。
索引 | 名称 | 描述 |
---|---|---|
0 | 坐标 | [10][4] 多维数组,每一个元素由 0 到1 之间的浮点数,内部数组表示了矩形边框的 [top, left, bottom, right] |
1 | 类型 | 10个整型元素组成的数组(输出为浮点型值),每一个元素代表标签文件中的索引。 |
2 | 分数 | 10个整型元素组成的数组,元素值为 0 至 1 之间的浮点数,代表检测到的类型 |
3 | 检测到的物体和数量 | 长度为1的数组,元素为检测到的总数 |
类别 | 分数 | 坐标 |
---|---|---|
苹果 | 0.92 | [18, 21, 57, 63] |
香蕉 | 0.88 | [100, 30, 180, 150] |
草莓 | 0.87 | [7, 82, 89, 163] |
香蕉 | 0.23 | [42, 66, 57, 83] |
苹果 | 0.11 | [6, 42, 31, 58] |
int NUM_DETECTIONS = 10;
outputLocations = new float[1][NUM_DETECTIONS][4];
outputClasses = new float[1][NUM_DETECTIONS];
outputScores = new float[1][NUM_DETECTIONS];
numDetections = new float[1];
Object[] inputArray = {imgData};
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, outputLocations);
outputMap.put(1, outputClasses);
outputMap.put(2, outputScores);
outputMap.put(3, numDetections);
Trace.endSection();
// Run the inference call.
Trace.beginSection("run");
tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
Trace.endSection();
3. 下载完成对象检测的android demo,在android studio 中运行。
已运行正常的demo 链接:https://download.csdn.net/download/Chhjnavy/21569431
从github 上下载demo 运行会遇到以下几个问题:
(1) Connection timed out: connect. If you are behind an HTTP proxy, please configure the proxy settings either in IDE or Gradle.
原因分析:TF2_object_detection\app\download_model.gradle 文件中需要下载detect.tflite ,由于外网限制,所以无法下载。
task downloadModelFile(type: Download) {
src 'https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2?lite-format=tflite'
dest project.ext.ASSET_DIR + '/detect.tflite'
overwrite false
}
preBuild.dependsOn downloadModelFile
解决问题:通过Home | TensorFlow Hub (google.cn) 加速下载,将
src 'https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2?lite-format=tflite'
#更新为
src 'https://hub.tensorflow.google.cn/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1'
再次运行就会发现 src\main\assets\ 已经有 detect.tflite 文件
2)出现第二个问题:The identifier of the model is invalid. The buffer may not be a valid TFLite model flatbuffer.
原因分析:下载下来的detect.tflite 是个无效的文件
解决问题:通过翻墙外网还按照https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/2?lite-format=tflite 下载,下载的文件名字是:lite-model_ssd_mobilenet_v1_1_metadata_2.tflite,将其放入 src\main\assets\ 中。
但是download_model.gradle 依然将src 改成: https://hub.tensorflow.google.cn/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1 ,而且保留 src\main\assets\ 中的 detect.tflite 文件,这样就不会出现Connection timed out 的问题。
3)将demo 中 DetectorActivity.java 以及 DetectorTest.java 中两处加载 detect.tflite 文件的地方修改成 lite-model_ssd_mobilenet_v1_1_metadata_2.tflite :
4. demo 路径:build\outputs\apk\interpreter\debug\ 中有 TF2.apk,直接安装再android 手机上即可运行。
请关注下一篇博客:ubuntu20.10 tensorflow2.5 将训练后的模型移植到android 平台之自己训练模型运行(三)