tensorflow object detect API ssd-mobilenet-v1网络结构源码修改

最近使用TensorFlow object_detect API做目标检测任务,由于要求目标检测模型能够移植客户端中,进而选择目标检测模型时则选择轻量级的模型,最后选择了ssd_mobilenet_v1作为目标检测的模型。

之前写过了TensorFlow object_detect API训练自己数据的步骤以及通过修改配置文件参数降低模型输入大小和模型通道数方法达到压缩模型提高推理速率(此处)。本篇文章主要讲述如何通过修改object_detect API中网络结构源码,以ssd_mobilenet_v1为例。

  • 1.模型结构定义源码

若要在源码中修改mobilenet-ssd-v1结构,可修改的源码文件:

    (1). models-master/research/slim/nets/mobilenet_v1.py   #模型网络的主结构

    (2). models-master/research/object_detection/models/ssd_mobilenet_v1_feature_extractor.py     #anchors特征提取层

mobilenet_v1.py:

 

mobilenet_v1.py主要定义ssd_mobilenet_v1网络的主结构层,若需要对网络进行删减可修改上图中MOBILENET_CONV_DEFS变量。(若进行网络层删减后配置文件中的层数也需要进行相应修改)

ssd_mobilenet_v1_feature_extractor.py:

 

ssd_mobilenet_v1_feature_extractor.py为anchors特征提取层,anchors特征提取对应的层修改可以修改上图中的feature_map_layout变量。(anchors特征层删减后配置文件中的相关参数也需要进行修改)

配置文件参数如下:

 

  • 2. anchors特征提取层修改实践

接下来举例如何通过修改源码文件ssd_mobilenet_v1_feature_extractor.py将anchors提取的6层结构改为4层。

首先修改源码文件

models-master/research/object_detection/models/ssd_mobilenet_v1_feature_extractor.py,将feature_map_layout变量中的from_layer和layer_depth后2层去掉如下:

 

然后修改配置文件如下:

将num_layers由6改为4,并去掉2个aspect_ratios

 

这样便可以开始目标检测模型训练(由于网络结构修改了,预加载官方模型的参数需要在配置文件中去掉,改用不预加载模型的方法训练)

 

  • 3. 网络结构修改前后对比图

修改前网络结构(部分结构):

 

修改后网络结构(部分结构):

 

    可从修改前后的两个网络结构图中看出通过源码修改后,目标检测最后6层的特征提取成功变为4层,总体上能够减少一些计算量(后面的层计算量很少)。4个anchors特征提取层的目标检测模型依旧能够正常检测(精度会有相应损失)。

基于TensorFlow Object Detection API搭建自己的物体识别模型的代码如下: 1. 准备工作: - 安装TensorFlow Object Detection API - 准备训练和测试数据集 - 下载预训练的模型权重 2. 导入所需库: ```python import tensorflow as tf from object_detection.utils import dataset_util from object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util ``` 3. 加载label map和模型: ```python PATH_TO_LABELS = 'path_to_label_map.pbtxt' PATH_TO_MODEL = 'path_to_pretrained_model' label_map = label_map_util.load_labelmap(PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=90, use_display_name=True) category_index = label_map_util.create_category_index(categories) detection_graph = tf.Graph() with detection_graph.as_default(): od_graph_def = tf.GraphDef() with tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name='') ``` 4. 定义函数进行物体识别: ```python def detect_objects(image): with detection_graph.as_default(): with tf.Session(graph=detection_graph) as sess: image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0') detection_scores = detection_graph.get_tensor_by_name('detection_scores:0') detection_classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') image_expanded = np.expand_dims(image, axis=0) (boxes, scores, classes, num) = sess.run( [detection_boxes, detection_scores, detection_classes, num_detections], feed_dict={image_tensor: image_expanded}) vis_util.visualize_boxes_and_labels_on_image_array( image, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates=True, line_thickness=8) return image ``` 5. 加载测试图像并进行物体识别: ```python image = cv2.imread('test_image.jpg') image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) output_image = detect_objects(image) cv2.imshow('Object Detection', output_image) cv2.waitKey(0) cv2.destroyAllWindows() ``` 通过以上代码,可以使用自己的训练数据集、预训练模型权重和标签映射文件来搭建自己的物体识别模型。设置好路径并加载模型后,将待识别的图像传入`detect_objects`函数即可返回识别结果,并在图像上进行可视化展示。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值