TensorFlow Object Detection API 上提供了使用SSD部署到TFLite运行上去的方法, 可是这套API封装太死板, 如果你要自己实现了一套SSD的训练算法,应该怎么才能部署到TFLite上呢?
首先,抛开后处理的部分,你的SSD模型(无论是VGG-SSD和Mobilenet-SSD), 你最终的模型的输出是对class_predictions和bbox_predictions; 并且是encoded的
Encoding的方式:
class_predictions: M个Feature Layer, Feature Layer的大小(宽高)视网络结构而定; 每个Feature Layer有Num_Anchor_Depth_of_this_layer x Num_classes个channels
box_predictions: M个Feature Layer; 每个Feature Layer有Num_Anchor_Depth_of_this_layer x 4个channes 这4个channel分别代表dy,dx,h,w, 即bbox中心距离anchor中心坐标的偏移量和宽高
注:通常,为了平衡loss之间的大小, 不会直接编码dy,dx,w,h的原始值,而是dy/anchor_h*scale0, dx/anchor_w*scale0, log(h/anchor_h)*scale1, log(w/anchor_w)*scale1, 也就是偏移量的绝对值除anchor宽高得到相对值,然后再乘上一个scale, 经验值 scale0取5,scale1取10; 对于h,w是对得到相对值后先取log再乘以scale, h/anchor_h的范围在1附近, 取log后可以转换到0附近;所以在解码的时候需要做对应相反的变换;
在后面TFLite_Detection_PostProcess的Op实现里就有这么一段:
然后我们需要的是做的是decode出来对每个class的confidence和location的预测值
后处理
在Object Detection API的 export_tflite_ssd_graph_lib.py文件中,你可以看到,它区别与直接freeze pb的操作就在于最后替换了后处理的部分;
Plain Text
Plain Text
Bash
Basic
C
C++
C#
CSS
C++
Diff
Git
go
GraphQL
HTML
HTTP
Java
JavaScript
JSON
JSX
Kotlin
Less
Makefile
Markdown
MATLAB
Nginx
Objective-C
Pascal
Perl
PHP
PowerShell
Ruby
Protobuf
Python
R
Ruby
Scala
Shell
SQL
Swift
TypeScript
VB.net
XML
YAML
KaTeX
Mermaid
PlantUML
Flow
Graphviz
frozen_graph_def = exporter.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().a