详解 TensorFlow TFLite 移动端(安卓)部署物体检测 demo(2)——量化模型

写在前面

上篇写了使用 TensorFlow 提供的 examples 项目在安卓手机移动端部署一个简单的物体检测 demo。写到换自定义模型部署我给跑了,现在回来接着说(/(ㄒoㄒ)/~~) 上篇写到了想要换成自己的模型大概的步骤,也就是先训自己的模型,然后要 freeze 它,然后 convert 成 tflite 格式,最后再把 labelmap 信息作为 metadata 写入模型输出最后的 tflite放到手机上运行。这几句话说得很笼统,但是想先往前走一走,所以先不说那么详尽了(说的多错的多23333)。

这个 demo 下会先给一个现成可拿来测试的、平时 PC 上训练得到的模型,所以我们可以先不说训练模型的事情。我们就先用这个现成的,主要走后面那三步,也就是 freeze、 convert 和 metadata 三个步骤。

我上一篇中已经说过了,到目前为止 tensorflow 的文档没有全部更新,所以造成很多重要引导信息不能直接获取到。截止到现在,app 下 README 文件还没给最后 metadata 这一步骤以及对应的简要代码(tensorflow 已经提供了这部分代码)。所以仍然照旧,只想看流程的可以直接看“上帝视角”,还想看看我曾经如何找路的可以看“地图探索”

老规矩,害怕把人吓走,先放点结果看看。本篇里面换模型了呦,用的 oid 数据集,里面有一类物体是 football 足球,这在上一篇模型(COCO)数据集里是没有的哦。(我没有足球,这是显示屏,开的百度搜索到的足球图片233333)

在这里插入图片描述

那那那那可以直接开始 freeze,convert 和 metadata 了吗?

不可以。

下载并安装 models

(1)“上帝视角” 下载 models:
完成这两个步骤就需要先下载并安装上篇提到的 model 项目了(本篇中 examples 项目权重就占得少了,主要是 models 项目的介绍比较多)。

cd /path/to/put/models
https://github.com/tensorflow/models.git

models 这个项目下面是官方还有很多研究者提供的模型相关资料。我们最关心的部分是 models/research/object_detection ,这个文件夹下的内容就是物体检测的基本工具和信息,以及 freeze 和 convert 相关的代码。

(2)安装 models:
“上帝视角”:反正我都试了,不安装 models 也行(忙活半天,哎……tensorflow 官方文档真的害人不浅啊/(ㄒoㄒ)/~~)

“地图探索”:此步骤的经历比较曲折,去年做的时候,我是老老实实按照引导来的。之前的 models/research 文件夹下是有一个 setup.py 文件的,通过它就可以安装 object_detection 。 examples 项目那边 app 路径下面的 README 仍然还保留着这个安装的步骤,但是目前最新的 models/research 下面没这个 setup.py 文件了。我不确定是这个 demo 整体更新过不再需要这一步,还是单纯地缺失了 setup.py 文件。整得我有点懵,所以我进行了两种尝试


2021.07 更新
关于 setup 文件“丢失”情况,在 /models/research/object_detection/g3doc/ 路径下 文件 tf1.md 和 tf2.md 中有此 API 的安装方式,大同小异。


尝试安装:我复制了以前版本 models/research 路径下的 setup.py 文件,并且照着以前的安装方式,也成功安装了。总之最后会在你的环境里添加 object_detection1.0 这个包,具体安装命令比较简单,但是里面实际放生步骤比较多,因为想清晰知道发生了什么,我特意根据输出记了一下,有点长,单独放在这里

尝试不安装:我不安装,跳过这一步,为了验证,我还特意建了全新的环境,仍然可以完成整个 freeze 和 convert 过程。所以,emmm 只能说,不安装 models 这一步,仍然可以完成 freeze 和 convert,但是会不会有别的隐患,我目前的二次梳理还没有遇到,因为 freeze 和 convert 也只是整个 modles 项目中的其中一部分功能而已,目前我不确定其他别的使用会不会受到此步骤的影响。

测试不安装 models 新建的环境,我装的 tf1,原因下一段写:

conda create -n lxdpy364tflite python=3.6
conda install tensorflow-gpu=1.14
pip install tf_slim
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple/ matplotlib

是时候操心一下环境了

tf 版本用 v1 还是用 v2?

“上帝视角”:取决于你到底用哪个模型,我这个记录里面用的那个 ssd_mobilenet_v2 必须用 tf1。

tf1 和 tf2 分别支持的模型,在 models/research/object_detection/g3doc 下的 tf1_detection_zoo.mdtf2_detection_zoo.md 文件中有详细的列表介绍。这点是无法通过 tf2 compat v1 替代的。

“地图探索”:插一句有用的废话,按照我之前(去年)的路程,这个 models 里面有很多 tf 1 和 2 之间的坑……我当时每天上班都骂骂咧咧的,就是吧,感觉它在 1 和 2 之间好像能无缝切换,但反正根本不是这样。

最后,(针对 app 里面介绍 freeze 和 convert 提供的 ssd_mobilenetv2_oidv4)我是在 tfv1 环境下跑通的,是因为 ssd_mobilenet_v2 仅仅在 tfv1 的模型种类构建范围内(这个后面再说下,截止目前 tf1 和 tf2 的范围是不一样的)。

然后现在我看这个 models 项目,好像在 tf 1 和 2 兼容上又做工作了?我以为我可以直接使用 tf2 了?毕竟官方都说了鼓励使用 tf2,我就信了他们的邪……我就是 TMD 记吃不记打/(ㄒoㄒ)/~~

所以我用的 v2 版本跑了一遍,一路处理下去依次解决问题后,仍然卡在 tfv2 不提供 ssd_mobilenetv2_oidv4 的模型构建功能,针对 ssd_mobilenet_v2 只能用 tfv1。

如果你用 tf2 跑这个 ssd_mobilenet_v2 的话,即使你一路下去解决各种小报错,最终也会卡在这里,报错的最后一行也有明确说明,这个模型根本不在目前 tf2 支持的范围内:

 File "object_detection/export_tflite_ssd_graph.py", line 140, in main
    FLAGS.max_classes_per_detection, use_regular_nms=FLAGS.use_regular_nms)
  File "/home/yx-lxd/mobile_work/tensorflow_models/research/object_detection/export_tflite_ssd_graph_lib.py", line 246, in export_tflite_graph
    pipeline_config.model, is_training=False)
  File "/home/yx-lxd/mobile_work/tensorflow_models/research/object_detection/builders/model_builder.py", line 1227, in build
    add_summaries)
  File "/home/yx-lxd/mobile_work/tensorflow_models/research/object_detection/builders/model_builder.py", line 391, in _build_ssd_model
    _check_feature_extractor_exists(ssd_config.feature_extractor.type)
  File "/home/yx-lxd/mobile_work/tensorflow_models/research/object_detection/builders/model_builder.py", line 265, in _check_feature_extractor_exists
    'Tensorflow'.format(feature_extractor_type))
ValueError: ssd_mobilenet_v2 is not supported. See `model_builder.py` for features extractors compatible with different versions of Tensorflow
# 最后一行已经说明问题,它不在 tf2 的支持范围内

下载现成的 ckpt 模型

“上帝视角”:本小标题下都是必要的步骤,请放心食用~

这里下载刚刚提到的模型,就是下图选中的那个 ssd_mobilenetv2_oidv4,科学上网。我把这个模型放在网盘了,g5yf。

在这里插入图片描述
解压之后长这样,ckpt 格式的,马上就可以拿来用了。
在这里插入图片描述
这个模型是在数据集 oid 上训练的,所以 labelmap 要换,在这里下载,记得取出来第二列单独准备在 labelmap.txt 文件里面,后面会和这个模型搭配使用。这个数据集 lablemap 信息前几项长这样:

Tortoise
Container
Magpie
Sea turtle
Football
Ambulance

你看这里有个 Football,COCO 数据集里没有的。

现在能开始 freeze, convert 和 metadata 了吗?

emmm,能,只能开一点点,不能开多了。来自 b 站 几加乘,给我过审吧/(ㄒoㄒ)/~~
在这里插入图片描述

开始 freeze

代码

到现在这个进度,我们下一步的目标是实现 freeze,freeze 只是固化模型,不会改变模型内参数的数值和类型,因此也不会改变模型大小,代码是:

cd /models/research

python object_detection/export_tflite_ssd_graph.py --pipeline_config_path object_detection/samples/configs/ssd_mobilenet_v2_oid_v4.config --trained_checkpoint_prefix /home/yx-lxd/mobile_work/ssd_mobilenet_v2_oid_v4_2018_12_12/model.ckpt --output_directory /home/yx-lxd/mobile_work/
  • object_detection/export_tflite_ssd_graph.py 这里是想运行这个代码;
  • --pipeline_config_path models 项目是给了的,这点没问题,就是 /models/research/object_detection/samples/configs/ssd_mobilenet_v2_oid_v4.config
  • --trained_checkpoint_prefix 这点也没问题,就是刚刚下载的那个待使用的现成的模型;
  • --output_directory,这点更没问题,自己设置的输出路径,这个代码执行之后,会 freeze ckpt 模型,生成 .pb 和 .pbtxt 两个文件。
报错处理

但是吧,我之所以说能开始 freeze 但只能开一点点的原因是,此时运行上面这个代码,会报错的呀,而且不止一处哦……我第一次跑通就只看报错信息慢慢修改的。现在整理一下报错可以分为几种情况吧:

(1)tf 1 和 2 之间版本的问题;
(2)本来就有一些细节,在 app 下的 README 没有完全讲清楚,需要再跳转到相关的文档里查看。
(3)给的文档里面也有猝不及防完全没有提及的,根据报错信息解决就好了。

报错举例:
(1)某些代码中找不到包的,比如我曾经遇到过的有 official, nets 等,对应添加环境变量即可,其中 official 在 /models 下面,所以添加 /models;nets 在 /models/research/slim 下面,所以添加 /models/research/slim
(2)找不到 xxx.pb2.py 文件,cd /models/research protoc object_detection/protos/*.proto --python_out=.
(3)from tensorflow.python.keras.applications import resnet 没有 resnet。

File "/home/yx-lxd/mobile_work/tensorflow_models/research/object_detection/models/keras_models/resnet_v1.py", line 22, in <module>
    from tensorflow.python.keras.applications import resnet
# 将 resnet 及其调用改为 resnet50,tfv1 下就是没有 `resnet`,只有 `resnet50`,此处替换可以跑通

上面问题都解决,或者灵活根据报错改正之后,再运行 python object_detection/export_tflite_ssd_graph.py --pipeline_config_path object_detection/samples/configs/ssd_mobilenet_v2_oid_v4.config --trained_checkpoint_prefix /home/yx-lxd/mobile_work/ssd_mobilenet_v2_oid_v4_2018_12_12/model.ckpt --output_directory /home/yx-lxd/mobile_work/

得到如下,即为 freeze 成功。

# freeze 成功
I0628 14:04:20.163475 140413095548672 graph_util_impl.py:311] Froze 324 variables.
INFO:tensorflow:Converted 324 variables to const ops.
I0628 14:04:20.240427 140413095548672 graph_util_impl.py:364] Converted 324 variables to const ops.
2021-06-28 14:04:20.341241: I tensorflow/tools/graph_transforms/transform_graph.cc:317] Applying strip_unused_nodes

扩展

灵魂质问:作业做完了吗? freeze 这就完了吗?里面到底怎么做?是不是应该详细看一下 export_tflite_ssd_graph.py
拒绝挨打:我真的做作业了鸭,细节太多篇幅太长了鸭,放在TensorFlow TFLite 移动端(安卓)部署中 freeze 的细节了鸭~

开始 convert

压缩模型,将模型中的参数类型改变生成能够在移动端运行的 tflite 就在这一步。根据你是否要将参数量化存入 tflite 可以分为两种情况,即量化和不量化。

也就是说,在这一步你不量化也能生成 tflite 格式的模型用于移动端,但是模型文件大小基本不变,运算资源也会使用得更多。如果成功实现量化,会看到输出的 tflite 模型文件基本是原来 pb 文件的 1/4。

但根据我的测试,无论这里是否模型大小降为 1/4, Android Studio 打开的项目中参数 TF_OD_API_IS_QUANTIZED 都应该设置为 false。此参数控制的是输入数据的情况,也就是接收 0-255 uint8 的图像数据还是 float32 且归一化后的输入数据。

代码

据我了解有三种方式,但是本质上应该是一样的。

(1)根据 app 下 README 文件的引导,下面一行就可以直接实现 convert pb 文件到 tflite 文件。

tflite_convert --input_shape=1,300,300,3 --input_arrays=normalized_input_image_tensor --output_arrays=TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3 --allow_custom_ops --graph_def_file=/home/yx-lxd/mobile_work/tflite_graph.pb --output_file=/home/yx-lxd/mobile_work/detect.tflite
  • input_shape 是模型接受的输入大小,[batchsize, heigh, width, channel];
  • input_arrays 是模型接受输入的 op 名字,此处 demo 给了这个模型现成的,就是 normalized_input_image_tensor
  • output_arrays 也是模型的输出,此处有 4 个,demo 给的现成的,依次是 TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3
  • allow_custom_ops 允许自定义 op;
  • --graph_def_file 上一步刚刚 freeze 的 pb 模型;
  • output_file 输出 tflite 的路径。

这段是 app 下 README 文件给的代码示例,其中没有提到要量化的事情,所以生成的 tflite 模型文件大小几乎不变(如下图选中的两个文件),对应的 TF_OD_API_IS_QUANTIZED 应为 false。

如果你想用这种方式实现量化,你可以在代码里面加上量化的设置 --post_training_quantize 即可,只需要加上,不需要设置 True。这样你就能得到一个 14.6 M 的 tflite 模型文件(pb 是 58.1 M)。

(2)也可以写个脚本来实现同样的事情,是否量化只需要更改设置 converter.post_training_quantize

import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
# Define pb file to convert
graph_def_file = "/home/yx-lxd/mobile_work/tflite_graph.pb"

# input_arrays has been determined if you get the pb file by export_tflite_ssd_graph.py 
input_arrays = ["normalized_input_image_tensor"] 
# If 'allow_custom_ops = True', length should be 4
output_arrays = \
['TFLite_Detection_PostProcess', \
'TFLite_Detection_PostProcess:1', \
'TFLite_Detection_PostProcess:2', \
'TFLite_Detection_PostProcess:3']

converter =  \
tf.lite.TFLiteConverter.from_frozen_graph(graph_def_file,  \
input_arrays,  \
output_arrays,  \
input_shapes={"normalized_input_image_tensor":[1,300,300,3]})	# here 300 * 300 is the input size

converter.allow_custom_ops=True
# If you want to quantize the model, set it True; else, set it False
converter.post_training_quantize = True 
tflite_model = converter.convert()

open("/home/yx-lxd/mobile_work/detect.tflite", "wb").write(tflite_model)

(3)还有一种用 bazel 方式的,我没有实际运行,就不在这里贴代码了,来源于 tensorflow 文档,在 /models/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md 中,可以自行查看。

结果
# 并没有打印成功等信息,但没有报错,在对应路径下,可以找到 tflite 文件。
2021-06-28 14:17:04.157686: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x4fb1350 executing computations on platform CUDA. Devices:
2021-06-28 14:17:04.157698: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): GeForce GTX 1060 6GB, Compute Capability 6.1
扩展

如果又换别的模型了,前面那些 input_arrays 和 output_arrays 怎么设置呢?

其实 convert 步骤中的 input_arrays 和 output_arrays 是在上一个步骤 freeze 中相关的代码中设定的,也就是自定义的,即 export_tflite_ssd_graph.py 相关代码中,细节同样请看TensorFlow TFLite 移动端(安卓)部署中 freeze 的细节

最后的 metadata

“上帝视角”:直接贴代码,来源在此。 /(ㄒoㄒ)/~~为什么文档不直接指引到这里

from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils

ObjectDetectorWriter = object_detector.MetadataWriter
_MODEL_PATH = "/home/yx-lxd/mobile_work/ssd_mobilenet_v2_oid_v4_2018_12_12/detect.tflite"

# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "/home/yx-lxd/mobile_work/ssd_mobilenet_v2_oid_v4_2018_12_12/labelmap.txt"
_SAVE_TO_PATH = "/home/yx-lxd/mobile_work/ssd_mobilenet_v2_oid_v4_2018_12_12/ssd_mobilenet_v2_metadata.tflite"

# Normalization parameters is required when reprocessing the image. It is
# optional if the image pixel values are in range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction for normalization and
# quantization parameters below for more details.
# https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters)

_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5

# Create the metadata writer.
writer = ObjectDetectorWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
    [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

把这个步骤生成的 tflite 模型 + 前面“下载现成的 ckpt 模型”步骤中准备的 labelmap.txt复制到 /examples/lite/examples/object_detection/android/app/src/main/assets/ 下记得改名,去 Android Studio 安装就可以啦……。

得到的结果就是本篇最上面的动图啦,“上帝模式” 的用户已经到终点站了,可以关掉网页啦O(∩_∩)O哈哈~。

“地图探索”:OK,苦逼模式开启,最早文档里面没有说 metadata 这一步的事情呀,所以我把 convert 之后的 tflite(没有 metadata) 放到 Android Studio 里面跑就报错啦,报错就是无法获取模型的 Metadata,所以初始化检测器失败。

然后,只能一步步看报错,一步步去找各种其他相关的文档介绍。有点长,我还是另起一篇吧,放在这里篇幅太恐怖 liao~

最最后的回顾

还有哪些工作或者功能没有提到?

  • 第一就是基于本篇里面的 ssd mobilenet v2 怎么换私有数据集训练,检测特定的物体呢?关于这点,重开一篇
  • 第二就是能不能不用 ssd mobilenet v2 啦?tf 不是还提供了很多其他的模型吗,可以换了再走一遍本篇流程拿来用呢?关于这点,(⊙o⊙)…不是很有时间做啦,答案肯定是可以的啦,溜了溜了~
  • 4
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 11
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值